Skip to content

Commit ccdc5e8

Browse files
authored
Fix DAG reply for AI.TENSORGET op (#793)
* Fix bug in DAG reply for ai.tensorget op when the op was not executed due to an error that occurred beforehand. * Try fixing encoding problem in running tests in GPU with locals
1 parent a0ad9d8 commit ccdc5e8

File tree

3 files changed

+37
-28
lines changed

3 files changed

+37
-28
lines changed

Dockerfile.gpu-test

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@ SHELL ["/bin/bash", "-c"]
2323

2424
ENV NVIDIA_VISIBLE_DEVICES all
2525
ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
26-
26+
ENV LANG=en_US.UTF-8
27+
RUN apt-get update
28+
RUN apt-get install -y locales && \
29+
sed -i -e "s/# $LANG.*/$LANG UTF-8/" /etc/locale.gen && \
30+
dpkg-reconfigure --frontend=noninteractive locales && \
31+
update-locale LANG=$LANG
2732
WORKDIR /build
2833
COPY --from=redis /usr/local/ /usr/local/
2934

src/execution/DAG/dag.c

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -584,8 +584,12 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
584584

585585
case REDISAI_DAG_CMD_TENSORGET: {
586586
rinfo->dagReplyLength++;
587-
RAI_Tensor *t = Dag_GetTensorFromGlobalCtx(rinfo, currentOp->inkeys_indices[0]);
588-
ReplyWithTensor(ctx, currentOp->fmt, t);
587+
if (currentOp->result == -1) {
588+
RedisModule_ReplyWithSimpleString(ctx, "NA");
589+
} else {
590+
RAI_Tensor *t = Dag_GetTensorFromGlobalCtx(rinfo, currentOp->inkeys_indices[0]);
591+
ReplyWithTensor(ctx, currentOp->fmt, t);
592+
}
589593
break;
590594
}
591595

tests/flow/tests_dag_basic.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -230,31 +230,31 @@ def test_dag_with_timeout(env):
230230

231231
env.assertEqual(b'TIMEDOUT', res)
232232

233-
def test_dag_with_timeout(env):
233+
234+
def test_dag_with_error(env):
234235
if not TEST_TF:
235236
return
236-
con = env.getConnection()
237-
batch_size = 2
238-
minbatch_size = 2
239-
timeout = 1
240-
model_name = 'model{1}'
241-
model_pb, input_var, output_var, labels, img = load_mobilenet_v2_test_data()
242237

243-
con.execute_command('AI.MODELSTORE', model_name, 'TF', DEVICE,
244-
'BATCHSIZE', batch_size, 'MINBATCHSIZE', minbatch_size,
245-
'INPUTS', 1, input_var,
246-
'OUTPUTS', 1, output_var,
247-
'BLOB', model_pb)
248-
con.execute_command('AI.TENSORSET', 'input{1}',
249-
'FLOAT', 1, img.shape[1], img.shape[0], img.shape[2],
250-
'BLOB', img.tobytes())
251-
252-
res = con.execute_command('AI.DAGEXECUTE',
253-
'LOAD', '1', 'input{1}',
254-
'TIMEOUT', timeout, '|>',
255-
'AI.MODELEXECUTE', model_name,
256-
'INPUTS', 1, 'input{1}', 'OUTPUTS', 1, 'output{1}',
257-
'|>', 'AI.MODELEXECUTE', model_name,
258-
'INPUTS', 1, 'input{1}', 'OUTPUTS', 1, 'output{1}')
259-
260-
env.assertEqual(b'TIMEDOUT', res)
238+
con = env.getConnection()
239+
tf_model = load_file_content('graph.pb')
240+
ret = con.execute_command('AI.MODELSTORE', 'tf_model{1}', 'TF', DEVICE,
241+
'INPUTS', 2, 'a', 'b',
242+
'OUTPUTS', 1, 'mul',
243+
'BLOB', tf_model)
244+
env.assertEqual(b'OK', ret)
245+
246+
# Run the model from DAG context, where MODELEXECUTE op fails due to dim mismatch in one of the tensors inputs:
247+
# the input tensor 'b' is considered as tensor with dim 2X2X3 initialized with zeros, while the model expects that
248+
# both inputs to node 'mul' will be with dim 2.
249+
ret = con.execute_command('AI.DAGEXECUTE_RO', 'KEYS', 1, '{1}',
250+
'|>', 'AI.TENSORSET', 'a', 'FLOAT', 2, 'VALUES', 2, 3,
251+
'|>', 'AI.TENSORSET', 'b', 'FLOAT', 2, 2, 3,
252+
'|>', 'AI.MODELEXECUTE', 'tf_model{1}', 'INPUTS', 2, 'a', 'b', 'OUTPUTS', 1, 'tD',
253+
'|>', 'AI.TENSORGET', 'tD', 'VALUES')
254+
255+
# Expect that the MODELEXECUTE op will raise an error, and the last TENSORGET op will not be executed
256+
env.assertEqual(ret[0], b'OK')
257+
env.assertEqual(ret[1], b'OK')
258+
env.assertEqual(ret[3], b'NA')
259+
env.assertEqual(type(ret[2]), redis.exceptions.ResponseError)
260+
env.assertTrue(str(ret[2]).find('Incompatible shapes: [2] vs. [2,2,3] \t [[{{node mul}}]]') >= 0)

0 commit comments

Comments
 (0)