Skip to content

Commit fc20b9b

Browse files
author
DvirDukhan
committed
added DAG ROUTING keyword
1 parent 272a47a commit fc20b9b

File tree

8 files changed

+84
-54
lines changed

8 files changed

+84
-54
lines changed

docs/commands.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@ It accepts one or more operations, split by the pipe-forward operator (`|>`).
820820

821821
By default, the DAG execution context is local, meaning that tensor keys appearing in the DAG only live in the scope of the command. That is, setting a tensor with `TENSORSET` will store it local memory and not set it to an actual database key. One can refer to that key in subsequent commands within the DAG, but that key won't be visible outside the DAG or to other clients - no keys are open at the database level.
822822

823-
Loading and persisting tensors from/to keyspace should be done explicitly. The user should specify which key tensors to load from keyspace using the `LOAD` keyword, and which command outputs to persist to the keyspace using the `PERSIST` keyspace. The user can also specify keys in Redis that are going to be accessed for read/write operations (for example, from within `AI.SCRIPTEXECUTE` command), by using the keyword `KEYS`.
823+
Loading and persisting tensors from/to keyspace should be done explicitly. The user should specify which key tensors to load from keyspace using the `LOAD` keyword, and which command outputs to persist to the keyspace using the `PERSIST` keyspace. The user can also specify a tag or key which will assist for the routing of the DAG execution on the right shard in Redis that are going to be accessed for read/write operations (for example, from within `AI.SCRIPTEXECUTE` command), by using the keyword `ROUTING`.
824824

825825
As an example, if `command 1` sets a tensor, it can be referenced by any further command on the chaining.
826826

@@ -832,7 +832,7 @@ A `TIMEOUT t` argument can be specified to cause a request to be removed from th
832832
```
833833
AI.DAGEXECUTE [[LOAD <n> <key-1> <key-2> ... <key-n>] |
834834
[PERSIST <n> <key-1> <key-2> ... <key-n>] |
835-
[KEYS <n> <key-1> <key-2> ... <key-n>]]+
835+
[ROUTING <routing_tag>]]+
836836
[TIMEOUT t]
837837
|> <command> [|> command ...]
838838
```
@@ -841,9 +841,9 @@ _Arguments_
841841

842842
* **LOAD**: denotes the beginning of the input tensors keys' list, followed by the number of keys, and one or more key names
843843
* **PERSIST**: denotes the beginning of the output tensors keys' list, followed by the number of keys, and one or more key names
844-
* **KEYS**: denotes the beginning of keys' list which are used within this command, followed by the number of keys, and one or more key names. Alternately, the keys names list can be replaced with a tag which all of those keys share. Redis will verify that all potential key accesses are done to the right shard.
844+
* **ROUTING**: denotes the a key name or a tag that will assist in routing the dag execution command to the right shard. Redis will verify that all potential key accesses are done to within the target shard.
845845

846-
_While each of the LOAD, PERSIST and KEYS sections are optional (and may appear at most once in the command), the command must contain **at least one** of these 3 keywords._
846+
_While each of the LOAD, PERSIST and ROUTING sections are optional (and may appear at most once in the command), the command must contain **at least one** of these 3 keywords._
847847
* **TIMEOUT**: an optional argument, denotes the time (in ms) after which the client is unblocked and a `TIMEDOUT` string is returned
848848
* **|> command**: the chaining operator, that denotes the beginning of a RedisAI command, followed by one of RedisAI's commands. Command splitting is done by the presence of another `|>`. The supported commands are:
849849
* `AI.TENSORSET`
@@ -873,17 +873,17 @@ redis> AI.DAGEXECUTE PERSIST 1 predictions{tag} |>
873873
1) OK
874874
2) OK
875875
3) 1) FLOAT
876-
2) 1) (integer) 2
877-
2) (integer) 2
878-
3) "\x00\x00\x80?\x00\x00\x00@\x00\x00@@\x00\x00\x80@"
876+
1) 1) (integer) 2
877+
1) (integer) 2
878+
2) "\x00\x00\x80?\x00\x00\x00@\x00\x00@@\x00\x00\x80@"
879879
```
880880

881881
A common pattern is enqueuing multiple SCRIPTEXECUTE and MODELEXECUTE commands within a DAG. The following example uses ResNet-50,to classify images into 1000 object categories. Given that our input tensor contains each color represented as a 8-bit integer and that neural networks usually work with floating-point tensors as their input we need to cast a tensor to floating-point and normalize the values of the pixels - for that we will use `pre_process_3ch` function.
882882

883883
To optimize the classification process we can use a post process script to return only the category position with the maximum classification - for that we will use `post_process` script. Using the DAG capabilities we've removed the necessity of storing the intermediate tensors in the keyspace. You can even run the entire process without storing the output tensor, as follows:
884884

885885
```
886-
redis> AI.DAGEXECUTE KEYS 1 {tag} |>
886+
redis> AI.DAGEXECUTE ROUTING {tag} |>
887887
AI.TENSORSET image UINT8 224 224 3 BLOB b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00....' |>
888888
AI.SCRIPTEXECUTE imagenet_script{tag} pre_process_3ch INPUTS 1 image OUTPUTS 1 temp_key1 |>
889889
AI.MODELEXECUTE imagenet_model{tag} INPUTS 1 temp_key1 OUTPUTS 1 temp_key2 |>

src/execution/DAG/dag.c

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -667,8 +667,7 @@ int RedisAI_DagExecute_IsKeysPositionRequest_ReportKeys(RedisModuleCtx *ctx,
667667
size_t argpos = 1;
668668
while (argpos < argc) {
669669
const char *arg_string = RedisModule_StringPtrLen(argv[argpos++], NULL);
670-
if (!strcasecmp(arg_string, "LOAD") || !strcasecmp(arg_string, "PERSIST") ||
671-
!strcasecmp(arg_string, "KEYS")) {
670+
if (!strcasecmp(arg_string, "LOAD") || !strcasecmp(arg_string, "PERSIST")) {
672671
if (argpos >= argc) {
673672
return REDISMODULE_ERR;
674673
}
@@ -684,6 +683,11 @@ int RedisAI_DagExecute_IsKeysPositionRequest_ReportKeys(RedisModuleCtx *ctx,
684683
for (; argpos < last_argpos; argpos++) {
685684
RedisModule_KeyAtPos(ctx, argpos);
686685
}
686+
} else if (!strcasecmp(arg_string, "ROUTING")) {
687+
if (argpos >= argc) {
688+
return REDISMODULE_ERR;
689+
}
690+
RedisModule_KeyAtPos(ctx, argpos++);
687691
} else if (!strcasecmp(arg_string, "AI.MODELEXECUTE") ||
688692
!strcasecmp(arg_string, "AI.SCRIPTEXECUTE")) {
689693
if (argpos >= argc) {

src/execution/parsing/dag_parser.c

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ int DAGInitialParsing(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleSt
245245
bool load_complete = false;
246246
bool persist_complete = false;
247247
bool timeout_complete = false;
248-
bool keys_complete = false;
248+
bool routing_complete = false;
249249

250250
// The first arg is "AI.DAGEXECUTE(_RO) (or deprecated AI.DAGRUN(_RO))", so we go over from the
251251
// next arg.
@@ -280,13 +280,20 @@ int DAGInitialParsing(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleSt
280280
persist_complete = true;
281281
continue;
282282
}
283-
if (!strcasecmp(arg_string, "KEYS") && !keys_complete && chainingOpCount == 0) {
284-
const int parse_result =
285-
ValidateKeysArgs(ctx, &argv[arg_pos], argc - arg_pos, rinfo->err);
286-
if (parse_result <= 0)
283+
if (!strcasecmp(arg_string, "ROUTING") && !routing_complete && chainingOpCount == 0) {
284+
arg_pos++;
285+
if (arg_pos == argc) {
286+
RAI_SetError(rinfo->err, RAI_EDAGBUILDER, "ERR Missing ROUTING value");
287287
return REDISMODULE_ERR;
288-
arg_pos += parse_result;
289-
keys_complete = true;
288+
}
289+
if (!VerifyKeyInThisShard(ctx, argv[arg_pos++])) {
290+
RAI_SetError(
291+
rinfo->err, RAI_EDAGBUILDER,
292+
"ERR ROUTING value specified in the command hash to slot which does not "
293+
"belong to the current shard");
294+
return REDISMODULE_ERR;
295+
}
296+
routing_complete = true;
290297
continue;
291298
}
292299
if (!strcasecmp(arg_string, "TIMEOUT") && !timeout_complete && chainingOpCount == 0) {
@@ -318,10 +325,10 @@ int DAGInitialParsing(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleSt
318325
// commands).
319326
if (!strncasecmp(RedisModule_StringPtrLen(argv[0], NULL), "AI.DAGEXECUTE",
320327
strlen("AI.DAGEXECUTE"))) {
321-
if (!load_complete && !persist_complete && !keys_complete) {
328+
if (!load_complete && !persist_complete && !routing_complete) {
322329
RAI_SetError(rinfo->err, RAI_EDAGBUILDER,
323330
"ERR AI.DAGEXECUTE and AI.DAGEXECUTE_RO commands must "
324-
"contain at least one out of KEYS, LOAD, PERSIST keywords");
331+
"contain at least one out of ROUTING, LOAD, PERSIST keywords");
325332
return REDISMODULE_ERR;
326333
}
327334
}
@@ -335,9 +342,9 @@ int DAGInitialParsing(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleSt
335342
int ParseDAGExecuteCommand(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleString **argv,
336343
int argc, bool dag_ro) {
337344

338-
// The minimal command is of the form: AI.DAGEXECUTE(_RO) KEYS/LOAD/PERSIST 1 <key> |>
345+
// The minimal command is of the form: AI.DAGEXECUTE(_RO) ROUTING/LOAD/PERSIST 1 <key> |>
339346
// AI.TENSORGET <key>
340-
if (argc < 7) {
347+
if (argc < 6) {
341348
if (dag_ro) {
342349
RAI_SetError(rinfo->err, RAI_EDAGBUILDER,
343350
"ERR missing arguments for 'AI.DAGEXECUTE_RO' command");

src/execution/parsing/script_commands_parser.c

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,17 +147,25 @@ static int _ScriptExecuteCommand_ParseCommand(RedisModuleCtx *ctx, RedisModuleSt
147147
if (keysRequired) {
148148
const char *arg_string = RedisModule_StringPtrLen(argv[argpos], NULL);
149149
if (!strcasecmp(arg_string, "KEYS")) {
150-
keysDone = true;
151150
argpos++;
151+
keysDone = true;
152152
if (_ScriptExecuteCommand_ParseKeys(ctx, argv, argc, &argpos, error, sctx) ==
153153
REDISMODULE_ERR) {
154154
return REDISMODULE_ERR;
155155
}
156+
} else if (!strcasecmp(arg_string, "INPUTS")) {
157+
argpos++;
158+
inputsDone = true;
159+
if (_ScriptExecuteCommand_ParseInputs(ctx, argv, argc, &argpos, error, inputs) ==
160+
REDISMODULE_ERR) {
161+
return REDISMODULE_ERR;
162+
}
156163
}
157-
// argv[3] is not KEYS in AI.SCRIPTEXECUTE command (i.e., not in a DAG).
164+
// argv[3] is not KEYS or INPUTS in AI.SCRIPTEXECUTE command (i.e., not in a DAG).
158165
else {
159-
RAI_SetError(error, RAI_ESCRIPTRUN,
160-
"ERR KEYS scope must be provided first for AI.SCRIPTEXECUTE command");
166+
RAI_SetError(
167+
error, RAI_ESCRIPTRUN,
168+
"ERR KEYS or INPUTS scope must be provided first for AI.SCRIPTEXECUTE command");
161169
return REDISMODULE_ERR;
162170
}
163171
}

tests/flow/tests_commands.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,9 @@ def test_keys_syntax(env):
161161
"of given arguments",
162162
"AI.SCRIPTEXECUTE script{1} bar KEYS 2 key{1}")
163163

164-
# ERR KEYS missing in AI.SCRIPTEXECUTE
165-
check_error_message(env, con, "KEYS scope must be provided first for AI.SCRIPTEXECUTE command",
166-
"AI.SCRIPTEXECUTE script{1} bar INPUTS 2 a{1} a{1}")
167-
168-
# # ERR KEYS section in an inner AI.SCRIPTEXEUTE command within a DAG is not allowed.
169-
# check_error_message(env, con, "Already encountered KEYS scope in current command",
170-
# "AI.DAGEXECUTE KEYS 1 a{1} |> AI.SCRIPTEXECUTE script{1} bar KEYS 1 a{1}")
164+
# ERR KEYS or INPUTS missing in AI.SCRIPTEXECUTE
165+
check_error_message(env, con, "KEYS or INPUTS scope must be provided first for AI.SCRIPTEXECUTE command",
166+
"AI.SCRIPTEXECUTE script{1} bar OUTPUTS 2 a{1} a{1}")
171167

172168

173169
def test_scriptstore(env):
@@ -293,7 +289,7 @@ def test_pytorch_scriptexecute_errors(env):
293289

294290
check_error_message(env, con, "Invalid arguments provided to AI.SCRIPTEXECUTE", 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'KEYS', 1, '{1}', 'ARGS')
295291

296-
check_error_message(env, con, "KEYS scope must be provided first for AI.SCRIPTEXECUTE command", 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'INPUTS', 'OUTPUTS')
292+
check_error_message(env, con, "Invalid argument for inputs count in AI.SCRIPTEXECUTE", 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'INPUTS', 'OUTPUTS')
297293

298294
check_error_message(env, con, "Invalid value for TIMEOUT",'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'KEYS', 1, '{1}', 'INPUTS', 2, 'a{1}', 'b{1}', 'OUTPUTS', 1, 'c{1}', 'TIMEOUT', 'TIMEOUT')
299295

tests/flow/tests_dag.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -388,15 +388,15 @@ def test_dagexecute_modelexecute_multidevice_resnet(env):
388388
ensureSlaveSynced(con, env)
389389

390390
check_error_message(env, con, "INPUT key cannot be found in DAG",
391-
'AI.DAGEXECUTE', 'KEYS', '1', image_key, '|>', 'AI.SCRIPTEXECUTE', script_name, 'pre_process_3ch',
391+
'AI.DAGEXECUTE', 'ROUTING', image_key, '|>', 'AI.SCRIPTEXECUTE', script_name, 'pre_process_3ch',
392392
'INPUTS', 1, image_key, 'OUTPUTS', 1, temp_key1)
393393

394394
check_error_message(env, con, "INPUT key cannot be found in DAG",
395-
'AI.DAGEXECUTE', 'KEYS', '1', image_key, '|>', 'AI.MODELEXECUTE', model_name_0,
395+
'AI.DAGEXECUTE', 'ROUTING', image_key, '|>', 'AI.MODELEXECUTE', model_name_0,
396396
'INPUTS', 1, image_key, 'OUTPUTS', 1, temp_key1)
397397

398398
ret = con.execute_command('AI.DAGEXECUTE',
399-
'KEYS', 1, '{1}','|>',
399+
'ROUTING', '{1}','|>',
400400
'AI.TENSORSET', image_key, 'UINT8', img.shape[1], img.shape[0], 3, 'BLOB', img.tobytes(),'|>',
401401
'AI.SCRIPTEXECUTE', script_name, 'wrong_fn',
402402
'INPUTS', 1, image_key,
@@ -406,7 +406,7 @@ def test_dagexecute_modelexecute_multidevice_resnet(env):
406406
env.assertEquals("Function does not exist: wrong_fn", ret[1].__str__())
407407

408408
check_error_message(env, con, "Number of keys given as INPUTS here does not match model definition",
409-
'AI.DAGEXECUTE', 'KEYS', 1, '{1}',
409+
'AI.DAGEXECUTE', 'ROUTING', '{1}',
410410
'|>', 'AI.TENSORSET', image_key, 'UINT8', img.shape[1], img.shape[0], 3, 'BLOB', img.tobytes(),
411411
'|>',
412412
'AI.SCRIPTEXECUTE', script_name, 'pre_process_3ch',

tests/flow/tests_dag_basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_dag_load(env):
1919
def test_dag_local_tensorset(env):
2020
con = env.getConnection()
2121

22-
command = "AI.DAGEXECUTE KEYS 1 {1} |> " \
22+
command = "AI.DAGEXECUTE ROUTING {1} |> " \
2323
"AI.TENSORSET volatile_tensor1 FLOAT 1 2 VALUES 5 10 |> " \
2424
"AI.TENSORSET volatile_tensor2 FLOAT 1 2 VALUES 5 10 "
2525

@@ -34,7 +34,7 @@ def test_dag_local_tensorset(env):
3434
def test_dagro_local_tensorset(env):
3535
con = env.getConnection()
3636

37-
command = "AI.DAGEXECUTE_RO KEYS 1 {some_tag} |> " \
37+
command = "AI.DAGEXECUTE_RO ROUTING {some_tag} |> " \
3838
"AI.TENSORSET volatile_tensor1 FLOAT 1 2 VALUES 5 10 |> " \
3939
"AI.TENSORSET volatile_tensor2 FLOAT 1 2 VALUES 5 10 "
4040

@@ -246,7 +246,7 @@ def test_dag_with_error(env):
246246
# Run the model from DAG context, where MODELEXECUTE op fails due to dim mismatch in one of the tensors inputs:
247247
# the input tensor 'b' is considered as tensor with dim 2X2X3 initialized with zeros, while the model expects that
248248
# both inputs to node 'mul' will be with dim 2.
249-
ret = con.execute_command('AI.DAGEXECUTE_RO', 'KEYS', 1, '{1}',
249+
ret = con.execute_command('AI.DAGEXECUTE_RO', 'ROUTING', '{1}',
250250
'|>', 'AI.TENSORSET', 'a', 'FLOAT', 2, 'VALUES', 2, 3,
251251
'|>', 'AI.TENSORSET', 'b', 'FLOAT', 2, 2, 3,
252252
'|>', 'AI.MODELEXECUTE', 'tf_model{1}', 'INPUTS', 2, 'a', 'b', 'OUTPUTS', 1, 'tD',

0 commit comments

Comments
 (0)