1- import sys
1+ import argparse
22
33import numpy as np
44import requests
55import matplotlib .pyplot as plt
66
77
8- # This is only useful if testing the v1 endpoint
9- def inference_over_http_v1 (data : np .array ):
10- url = "http://localhost:8080/v1/models/my_model:predict"
11- headers = {"Content-Type" : "application/json" }
12- body = {"instances" : data .tolist ()}
13- response = requests .post (url , headers = headers , json = body )
14- response .raise_for_status ()
15- return response .json ()
8+ def inference_over_http (data : np .array , url : str | None = None , headers : dict | None = None ):
9+ if url is None :
10+ url = "http://localhost:8080/v2/models/my_model/infer"
11+ if headers is None :
12+ headers = {}
1613
17-
18- def inference_over_http (data : np .array ):
19- url = "http://localhost:8080/v2/models/my_model/infer"
20- headers = {"Content-Type" : "application/json" }
14+ headers ["Content-Type" ] = "application/json"
2115 body = {
2216 "inputs" : [
2317 {
@@ -34,25 +28,29 @@ def inference_over_http(data: np.array):
3428
3529
3630if __name__ == "__main__" :
37- try :
38- run_number = int (sys .argv [1 ])
39- except IndexError :
40- quit ("ERROR: Specify a run number while calling this script." )
41- except ValueError :
42- quit ("ERROR: The provided run number is invalid." )
31+ parser = argparse .ArgumentParser (description = "Process run number, optional URL, and optional headers." )
32+ parser .add_argument ("-r" , "--run_number" , type = int , help = "The run number (required integer argument)." )
33+ parser .add_argument ("-u" , "--url" , type = str , help = "Optional URL." )
34+ parser .add_argument ("-H" , "--headers" , type = str , nargs = '*' , help = "Optional headers in key=value format." )
35+ args = parser .parse_args ()
36+
37+ if args .run_number < 0 :
38+ quit ("ERROR: The provided run number must be a non-negative integer." )
39+ if args .headers :
40+ args .headers = dict (header .split ("=" , 1 ) for header in args .headers )
4341
4442 data_arr = np .load ("../../data/data_unfiltered.npy" )
4543 run_arr = np .load ("../../data/runs_unfiltered.npy" )
4644 unique_runs = np .unique (run_arr )
4745
48- if run_number not in unique_runs :
46+ if args . run_number not in unique_runs :
4947 quit (f"ERROR: The specified run number is not present in the sample data. Please, choose one of the following: { unique_runs .tolist ()} " )
5048
5149 # Collect the data to send to the model
52- target_data = data_arr [run_arr == run_number ]
50+ target_data = data_arr [run_arr == args . run_number ]
5351
5452 # Do the inference
55- predictions = inference_over_http (target_data )
53+ predictions = inference_over_http (target_data , args . url , args . headers )
5654
5755 # Parse the output predictions
5856 outputs = predictions ["outputs" ]
@@ -62,7 +60,7 @@ def inference_over_http(data: np.array):
6260 # Plot
6361 avg_mse = np .array (avg_mse )
6462 plt .figure (figsize = (15 , 5 ))
65- plt .plot (range (avg_mse .shape [0 ]), avg_mse , label = f"MSE { run_number } " )
63+ plt .plot (range (avg_mse .shape [0 ]), avg_mse , label = f"MSE { args . run_number } " )
6664 plt .legend ()
6765 plt .show ()
6866 plt .close ()
0 commit comments