import os
import sys
import time
import shutil
import argparse
import signal
import threading

import numpy as np
from redis import RedisError
from RLTest import Env
from includes import *

terminate_flag = 0
parent_pid = os.getpid()


# This handler captures a user SIGINT (such as keyboard Ctrl-C). Since we are using multi-processing,
# it is inherited by all the running processes. Note that every process will receive the signal,
# as all of them belong to the same process group.
def handler(signum, frame):
    global terminate_flag
    terminate_flag = 1
    if os.getpid() == parent_pid:  # print it only once
        print("\nReceived user interrupt. Shutting down...")


def _exit():
    # remove the logs that were auto-generated by redis
    shutil.rmtree('logs', ignore_errors=True)
    sys.exit(1)


def run_benchmark(env, num_runs_mnist, num_runs_inception, num_runs_bert, num_parallel_clients):
    global terminate_flag
    con = get_connection(env, '{1}')

    print("Loading ONNX models...")
    model_pb = load_file_content('mnist.onnx')
    sample_raw = load_file_content('one.raw')
    inception_pb = load_file_content('inception-v2-9.onnx')
    _, _, _, _, img = load_mobilenet_v2_test_data()
    bert_pb = load_file_content('bert-base-cased.onnx')
    bert_in_data = np.random.randint(-2, 1, size=(10, 100), dtype=np.int64)

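    # Store many copies of the mnist and inception models; the ONNX backend's total
    # memory consumption is reported below via the backends_info section.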
    for i in range(50):
        if terminate_flag == 1:
            _exit()
        ret = con.execute_command('AI.MODELSTORE', 'mnist{1}'+str(i), 'ONNX', DEVICE, 'BLOB', model_pb)
        env.assertEqual(ret, b'OK')
    con.execute_command('AI.TENSORSET', 'mnist_in{1}', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw)

    for i in range(20):
        if terminate_flag == 1:
            _exit()
        ret = con.execute_command('AI.MODELSTORE', 'inception{1}'+str(i), 'ONNX', DEVICE, 'BLOB', inception_pb)
        env.assertEqual(ret, b'OK')

    backends_info = get_info_section(con, 'backends_info')
    print(f'Done. ONNX memory consumption is: {backends_info["ai_onnxruntime_memory"]} bytes')

    ret = con.execute_command('AI.TENSORSET', 'inception_in{1}', 'FLOAT', 1, 3, 224, 224, 'BLOB', img.tobytes())
    env.assertEqual(ret, b'OK')
    ret = con.execute_command('AI.MODELSTORE', 'bert{1}', 'ONNX', DEVICE, 'BLOB', bert_pb)
    env.assertEqual(ret, b'OK')
    ret = con.execute_command('AI.TENSORSET', 'bert_in{1}', 'INT64', 10, 100, 'BLOB', bert_in_data.tobytes())
    env.assertEqual(ret, b'OK')

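    # Each client issues `num_runs` consecutive AI.MODELEXECUTE calls. The bert model takes three
    # INT64 inputs (token ids, attention mask and token-type ids, assuming the standard
    # bert-base-cased ONNX export); the same random tensor is reused for all three.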
    def run_parallel_onnx_sessions(con, model, input_key, num_runs):
        for _ in range(num_runs):
            if terminate_flag == 1:
                return
            # If the user terminates the benchmark, redis-server will receive the termination signal
            # as well, and a RedisError exception will be thrown (and caught).
            try:
                if model == 'bert{1}':
                    ret = con.execute_command('AI.MODELEXECUTE', model, 'INPUTS', 3, input_key, input_key, input_key,
                                              'OUTPUTS', 2, 'res{1}', 'res2{1}')
                else:
                    ret = con.execute_command('AI.MODELEXECUTE', model, 'INPUTS', 1, input_key, 'OUTPUTS', 1, 'res{1}')
                env.assertEqual(ret, b'OK')
            except RedisError:
                return

    def run_mnist():
        run_test_multiproc(env, '{1}', num_parallel_clients, run_parallel_onnx_sessions,
                           ('mnist{1}0', 'mnist_in{1}', num_runs_mnist))

    def run_bert():
        run_test_multiproc(env, '{1}', num_parallel_clients, run_parallel_onnx_sessions,
                           ('bert{1}', 'bert_in{1}', num_runs_bert))

    # run only mnist
    mnist_total_requests_count = num_runs_mnist*num_parallel_clients
    print(f'\nRunning {num_runs_mnist} consecutive executions of mnist from {num_parallel_clients} parallel clients...')
    start_time = time.time()
    run_test_multiproc(env, '{1}', num_parallel_clients, run_parallel_onnx_sessions,
                       ('mnist{1}0', 'mnist_in{1}', num_runs_mnist))
    if terminate_flag == 1:
        _exit()
    print(f'Done. Total execution time for {mnist_total_requests_count} requests: {time.time()-start_time} seconds')
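    # Field 11 of the AI.INFO reply holds the model's accumulated execution time
    # (in microseconds, hence the division by 1000000 below).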
    mnist_time = con.execute_command('AI.INFO', 'mnist{1}0')[11]
    print("Average serving time per mnist run session is: {} seconds"
          .format(float(mnist_time)/1000000/mnist_total_requests_count))

    # run only inception
    inception_total_requests_count = num_runs_inception*num_parallel_clients
    print(f'\nRunning {num_runs_inception} consecutive executions of inception from {num_parallel_clients} parallel clients...')
    start_time = time.time()
    run_test_multiproc(env, '{1}', num_parallel_clients, run_parallel_onnx_sessions,
                       ('inception{1}0', 'inception_in{1}', num_runs_inception))
    if terminate_flag == 1:
        _exit()
    print(f'Done. Total execution time for {inception_total_requests_count} requests: {time.time()-start_time} seconds')
    inception_time = con.execute_command('AI.INFO', 'inception{1}0')[11]
    print("Average serving time per inception run session is: {} seconds"
          .format(float(inception_time)/1000000/inception_total_requests_count))

    # run only bert
    bert_total_requests_count = num_runs_bert*num_parallel_clients
    print(f'\nRunning {num_runs_bert} consecutive executions of bert from {num_parallel_clients} parallel clients...')
    start_time = time.time()
    run_test_multiproc(env, '{1}', num_parallel_clients, run_parallel_onnx_sessions, ('bert{1}', 'bert_in{1}', num_runs_bert))
    if terminate_flag == 1:
        _exit()
    print(f'Done. Total execution time for {bert_total_requests_count} requests: {time.time()-start_time} seconds')
    bert_time = con.execute_command('AI.INFO', 'bert{1}')[11]
    print("Average serving time per bert run session is: {} seconds"
          .format(float(bert_time)/1000000/bert_total_requests_count))

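    # Reset the accumulated run statistics, so that the combined run below is measured from scratch.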
    con.execute_command('AI.INFO', 'mnist{1}0', 'RESETSTAT')
    con.execute_command('AI.INFO', 'inception{1}0', 'RESETSTAT')
    con.execute_command('AI.INFO', 'bert{1}', 'RESETSTAT')

    # run all 3 models in parallel
    total_requests_count = mnist_total_requests_count+inception_total_requests_count+bert_total_requests_count
    print(f'\nRunning requests for all 3 models from {3*num_parallel_clients} parallel clients...')
    start_time = time.time()
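    # run_test_multiproc blocks until its clients finish, so the mnist and bert client groups are
    # launched from side threads while the inception group runs on the main thread.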
    t = threading.Thread(target=run_mnist)
    t.start()
    t2 = threading.Thread(target=run_bert)
    t2.start()
    run_test_multiproc(env, '{1}', num_parallel_clients, run_parallel_onnx_sessions,
                       ('inception{1}0', 'inception_in{1}', num_runs_inception))
    t.join()
    t2.join()
    if terminate_flag == 1:
        _exit()
    print(f'Done. Total execution time for {total_requests_count} requests: {time.time()-start_time} seconds')
    mnist_info = con.execute_command('AI.INFO', 'mnist{1}0')[11]
    inception_info = con.execute_command('AI.INFO', 'inception{1}0')[11]
    bert_info = con.execute_command('AI.INFO', 'bert{1}')[11]
    total_time = mnist_info+inception_info+bert_info
    print("Average serving time per run session is: {} seconds"
          .format(float(total_time)/1000000/total_requests_count))


if __name__ == '__main__':

    # set a handler for user interrupt signal
    signal.signal(signal.SIGINT, handler)

    # parse command line arguments
    parser = argparse.ArgumentParser()
    # num_threads is kept as a string, since it is concatenated into moduleArgs below
    parser.add_argument("--num_threads", default='1',
                        help='Number of RedisAI working threads that can execute sessions in parallel')
    parser.add_argument("--num_runs_mnist", type=int, default=500,
                        help='Number of requests per client running mnist sessions')
    parser.add_argument("--num_runs_inception", type=int, default=50,
                        help='Number of requests per client running inception sessions')
    parser.add_argument("--num_runs_bert", type=int, default=5,
                        help='Number of requests per client running bert sessions')
    parser.add_argument("--num_parallel_clients", type=int, default=20,
                        help='Number of parallel clients that send consecutive run requests per model')
    args = parser.parse_args()

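    # A possible invocation (the script name here is an assumption):
    #   python3 onnx_benchmark.py --num_threads 8 --num_runs_mnist 1000 --num_parallel_clients 20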
    print(f'Running ONNX benchmark on RedisAI, using {args.num_threads} working threads')
    env = Env(module='install-cpu/redisai.so',
              moduleArgs='MODEL_EXECUTION_TIMEOUT 50000 THREADS_PER_QUEUE '+args.num_threads, logDir='logs')

    # If the user terminates the benchmark, redis-server will receive the termination signal as well,
    # and a RedisError exception will be thrown (and caught).
    try:
        run_benchmark(env, num_runs_mnist=args.num_runs_mnist, num_runs_inception=args.num_runs_inception,
                      num_runs_bert=args.num_runs_bert, num_parallel_clients=args.num_parallel_clients)
        env.stop()
    except RedisError:
        pass
    finally:
        # remove the logs that were auto-generated by redis
        shutil.rmtree('logs', ignore_errors=True)