1- import redis
1+ import numpy as np
22
33from includes import *
44
@@ -47,7 +47,61 @@ def test_run_tflite_model(env):
4747 env .assertEqual (values [0 ], 1 )
4848
4949
50- def test_run_tflite_model_errors (env ):
50+ def test_run_tflite_model_autobatch (env ):
51+ if not TEST_TFLITE :
52+ env .debugPrint ("skipping {} since TEST_TFLITE=0" .format (sys ._getframe ().f_code .co_name ), force = True )
53+ return
54+
55+ con = env .getConnection ()
56+ model_pb = load_file_content ('lite-model_imagenet_mobilenet_v3_small_100_224_classification_5_default_1.tflite' )
57+ _ , _ , _ , img = load_resnet_test_data ()
58+ img = img .astype (np .float32 ) / 255
59+
60+ ret = con .execute_command ('AI.MODELSTORE' , 'm{1}' , 'TFLITE' , 'CPU' ,
61+ 'BATCHSIZE' , 4 , 'MINBATCHSIZE' , 2 ,
62+ 'BLOB' , model_pb )
63+ env .assertEqual (ret , b'OK' )
64+
65+ ret = con .execute_command ('AI.MODELGET' , 'm{1}' , 'META' )
66+ env .assertEqual (len (ret ), 16 )
67+ if DEVICE == "CPU" :
68+ env .assertEqual (ret [1 ], b'TFLITE' )
69+ env .assertEqual (ret [3 ], b'CPU' )
70+
71+ ret = con .execute_command ('AI.TENSORSET' , 'a{1}' ,
72+ 'FLOAT' , 1 , img .shape [1 ], img .shape [0 ], 3 ,
73+ 'BLOB' , img .tobytes ())
74+ env .assertEqual (ret , b'OK' )
75+
76+ ret = con .execute_command ('AI.TENSORSET' , 'b{1}' ,
77+ 'FLOAT' , 1 , img .shape [1 ], img .shape [0 ], 3 ,
78+ 'BLOB' , img .tobytes ())
79+ env .assertEqual (ret , b'OK' )
80+
81+ def run ():
82+ con = env .getConnection ()
83+ con .execute_command ('AI.MODELEXECUTE' , 'm{1}' , 'INPUTS' , 1 ,
84+ 'b{1}' , 'OUTPUTS' , 1 , 'd{1}' )
85+ ensureSlaveSynced (con , env )
86+
87+ t = threading .Thread (target = run )
88+ t .start ()
89+
90+ con .execute_command ('AI.MODELEXECUTE' , 'm{1}' , 'INPUTS' , 1 , 'a{1}' , 'OUTPUTS' , 1 , 'c{1}' )
91+ t .join ()
92+
93+ ensureSlaveSynced (con , env )
94+
95+ values = con .execute_command ('AI.TENSORGET' , 'c{1}' , 'VALUES' )
96+ idx = np .argmax (values )
97+ env .assertEqual (idx , 112 )
98+
99+ values = con .execute_command ('AI.TENSORGET' , 'd{1}' , 'VALUES' )
100+ idx = np .argmax (values )
101+ env .assertEqual (idx , 112 )
102+
103+
104+ def test_run_tflite_errors (env ):
51105 if not TEST_TFLITE :
52106 env .debugPrint ("skipping {} since TEST_TFLITE=0" .format (sys ._getframe ().f_code .co_name ), force = True )
53107 return
@@ -64,13 +118,6 @@ def test_run_tflite_model_errors(env):
64118 check_error_message (env , con , "Failed to load model from buffer" ,
65119 'AI.MODELSTORE' , 'm{1}' , 'TFLITE' , 'CPU' , 'TAG' , 'asdf' , 'BLOB' , wrong_model_pb )
66120
67- # TODO: Autobatch is tricky with TFLITE because TFLITE expects a fixed batch
68- # size. At least we should constrain MINBATCHSIZE according to the
69- # hard-coded dims in the tflite model.
70- check_error_message (env , con , "Auto-batching not supported by the TFLITE backend" ,
71- 'AI.MODELSTORE' , 'm{1}' , 'TFLITE' , 'CPU' ,
72- 'BATCHSIZE' , 2 , 'MINBATCHSIZE' , 2 , 'BLOB' , model_pb )
73-
74121 ret = con .execute_command ('AI.TENSORSET' , 'a{1}' , 'FLOAT' , 1 , 1 , 28 , 28 , 'BLOB' , sample_raw )
75122 env .assertEqual (ret , b'OK' )
76123
@@ -82,6 +129,19 @@ def test_run_tflite_model_errors(env):
82129 check_error_message (env , con , "Number of keys given as INPUTS here does not match model definition" ,
83130 'AI.MODELEXECUTE' , 'm_2{1}' , 'INPUTS' , 3 , 'a{1}' , 'b{1}' , 'c{1}' , 'OUTPUTS' , 1 , 'd{1}' )
84131
132+ model_pb = load_file_content ('lite-model_imagenet_mobilenet_v3_small_100_224_classification_5_default_1.tflite' )
133+ _ , _ , _ , img = load_resnet_test_data ()
134+
135+ ret = con .execute_command ('AI.MODELSTORE' , 'image_net{1}' , 'TFLITE' , 'CPU' , 'BLOB' , model_pb )
136+ env .assertEqual (ret , b'OK' )
137+ ret = con .execute_command ('AI.TENSORSET' , 'dog{1}' , 'UINT8' , 1 , img .shape [1 ], img .shape [0 ], 3 ,
138+ 'BLOB' , img .tobytes ())
139+ env .assertEqual (ret , b'OK' )
140+
141+ # The model expects FLOAT input, but UINT8 tensor is given.
142+ check_error_message (env , con , "Input tensor type doesn't match the type expected by the model definition" ,
143+ 'AI.MODELEXECUTE' , 'image_net{1}' , 'INPUTS' , 1 , 'dog{1}' , 'OUTPUTS' , 1 , 'output{1}' )
144+
85145
86146def test_tflite_modelinfo (env ):
87147 if not TEST_TFLITE :
0 commit comments