@@ -230,31 +230,31 @@ def test_dag_with_timeout(env):
230230
231231 env .assertEqual (b'TIMEDOUT' , res )
232232
233- def test_dag_with_timeout (env ):
233+
234+ def test_dag_with_error (env ):
234235 if not TEST_TF :
235236 return
236- con = env .getConnection ()
237- batch_size = 2
238- minbatch_size = 2
239- timeout = 1
240- model_name = 'model{1}'
241- model_pb , input_var , output_var , labels , img = load_mobilenet_v2_test_data ()
242237
243- con .execute_command ('AI.MODELSTORE' , model_name , 'TF' , DEVICE ,
244- 'BATCHSIZE' , batch_size , 'MINBATCHSIZE' , minbatch_size ,
245- 'INPUTS' , 1 , input_var ,
246- 'OUTPUTS' , 1 , output_var ,
247- 'BLOB' , model_pb )
248- con .execute_command ('AI.TENSORSET' , 'input{1}' ,
249- 'FLOAT' , 1 , img .shape [1 ], img .shape [0 ], img .shape [2 ],
250- 'BLOB' , img .tobytes ())
251-
252- res = con .execute_command ('AI.DAGEXECUTE' ,
253- 'LOAD' , '1' , 'input{1}' ,
254- 'TIMEOUT' , timeout , '|>' ,
255- 'AI.MODELEXECUTE' , model_name ,
256- 'INPUTS' , 1 , 'input{1}' , 'OUTPUTS' , 1 , 'output{1}' ,
257- '|>' , 'AI.MODELEXECUTE' , model_name ,
258- 'INPUTS' , 1 , 'input{1}' , 'OUTPUTS' , 1 , 'output{1}' )
259-
260- env .assertEqual (b'TIMEDOUT' , res )
238+ con = env .getConnection ()
239+ tf_model = load_file_content ('graph.pb' )
240+ ret = con .execute_command ('AI.MODELSTORE' , 'tf_model{1}' , 'TF' , DEVICE ,
241+ 'INPUTS' , 2 , 'a' , 'b' ,
242+ 'OUTPUTS' , 1 , 'mul' ,
243+ 'BLOB' , tf_model )
244+ env .assertEqual (b'OK' , ret )
245+
246+ # Run the model from DAG context, where MODELEXECUTE op fails due to dim mismatch in one of the tensors inputs:
247+ # the input tensor 'b' is considered as tensor with dim 2X2X3 initialized with zeros, while the model expects that
248+ # both inputs to node 'mul' will be with dim 2.
249+ ret = con .execute_command ('AI.DAGEXECUTE_RO' , 'KEYS' , 1 , '{1}' ,
250+ '|>' , 'AI.TENSORSET' , 'a' , 'FLOAT' , 2 , 'VALUES' , 2 , 3 ,
251+ '|>' , 'AI.TENSORSET' , 'b' , 'FLOAT' , 2 , 2 , 3 ,
252+ '|>' , 'AI.MODELEXECUTE' , 'tf_model{1}' , 'INPUTS' , 2 , 'a' , 'b' , 'OUTPUTS' , 1 , 'tD' ,
253+ '|>' , 'AI.TENSORGET' , 'tD' , 'VALUES' )
254+
255+ # Expect that the MODELEXECUTE op will raise an error, and the last TENSORGET op will not be executed
256+ env .assertEqual (ret [0 ], b'OK' )
257+ env .assertEqual (ret [1 ], b'OK' )
258+ env .assertEqual (ret [3 ], b'NA' )
259+ env .assertEqual (type (ret [2 ]), redis .exceptions .ResponseError )
260+ env .assertTrue (str (ret [2 ]).find ('Incompatible shapes: [2] vs. [2,2,3] \t [[{{node mul}}]]' ) >= 0 )
0 commit comments