1- import sys
1+ import argparse
22
33import numpy as np
44import requests
55import matplotlib .pyplot as plt
66
77
8- def inference_over_http (data : np .array ):
9- url = "http://localhost:8080/v2/models/my_model/infer"
10- headers = {"Content-Type" : "application/json" }
8+ def inference_over_http (data : np .array , url : str | None = None , headers : dict | None = None ):
9+ if url is None :
10+ url = "http://localhost:8080/v2/models/my_model/infer"
11+ if headers is None :
12+ headers = {}
13+
14+ headers ["Content-Type" ] = "application/json"
1115 body = {
1216 "inputs" : [
1317 {
@@ -24,25 +28,29 @@ def inference_over_http(data: np.array):
2428
2529
2630if __name__ == "__main__" :
27- try :
28- run_number = int (sys .argv [1 ])
29- except IndexError :
30- quit ("ERROR: Specify a run number while calling this script." )
31- except ValueError :
32- quit ("ERROR: The provided run number is invalid." )
31+ parser = argparse .ArgumentParser (description = "Process run number, optional URL, and optional headers." )
32+ parser .add_argument ("-r" , "--run_number" , type = int , help = "The run number (required integer argument)." )
33+ parser .add_argument ("-u" , "--url" , type = str , help = "Optional URL." )
34+ parser .add_argument ("-H" , "--headers" , type = str , nargs = '*' , help = "Optional headers in key=value format." )
35+ args = parser .parse_args ()
36+
37+ if args .run_number < 0 :
38+ quit ("ERROR: The provided run number must be a non-negative integer." )
39+ if args .headers :
40+ args .headers = dict (header .split ("=" , 1 ) for header in args .headers )
3341
3442 data_arr = np .load ("../../data/data_unfiltered.npy" )
3543 run_arr = np .load ("../../data/runs_unfiltered.npy" )
3644 unique_runs = np .unique (run_arr )
3745
38- if run_number not in unique_runs :
46+ if args . run_number not in unique_runs :
3947 quit (f"ERROR: The specified run number is not present in the sample data. Please, choose one of the following: { unique_runs .tolist ()} " )
4048
4149 # Collect the data to send to the model
42- target_data = data_arr [run_arr == run_number ]
50+ target_data = data_arr [run_arr == args . run_number ]
4351
4452 # Do the inference
45- predictions = inference_over_http (target_data )
53+ predictions = inference_over_http (target_data , args . url , args . headers )
4654
4755 # Parse the output predictions
4856 outputs = predictions ["outputs" ]
@@ -52,7 +60,7 @@ def inference_over_http(data: np.array):
5260 # Plot
5361 avg_mse = np .array (avg_mse )
5462 plt .figure (figsize = (15 , 5 ))
55- plt .plot (range (avg_mse .shape [0 ]), avg_mse , label = f"MSE { run_number } " )
63+ plt .plot (range (avg_mse .shape [0 ]), avg_mse , label = f"MSE { args . run_number } " )
5664 plt .legend ()
5765 plt .show ()
5866 plt .close ()
0 commit comments