1- import sys
1+ import argparse
22
33import numpy as np
44import requests
55import matplotlib .pyplot as plt
66
77
8- def inference_over_http (data : np .array ):
9- url = "http://localhost:8501/v1/models/my_model:predict"
8+ def inference_over_http (data : np .array , url : str | None = None , headers : dict | None = None ):
9+ if url is None :
10+ url = "http://localhost:8501/v1/models/my_model:predict"
11+ if headers is None :
12+ headers = {}
13+
14+ headers ["Content-Type" ] = "application/json"
1015 body = {"instances" : data .tolist ()}
11- response = requests .post (url , json = body )
16+ response = requests .post (url , json = body , headers = headers )
1217 response .raise_for_status ()
1318 return response .json ()
1419
1520
1621if __name__ == "__main__" :
17- try :
18- run_number = int (sys .argv [1 ])
19- except IndexError :
20- quit ("ERROR: Specify a run number while calling this script." )
21- except ValueError :
22- quit ("ERROR: The provided run number is invalid." )
22+ parser = argparse .ArgumentParser (description = "Process run number, optional URL, and optional headers." )
23+ parser .add_argument ("-r" , "--run_number" , type = int , help = "The run number (required integer argument)." )
24+ parser .add_argument ("-u" , "--url" , type = str , help = "Optional URL." )
25+ parser .add_argument ("-H" , "--headers" , type = str , nargs = '*' , help = "Optional headers in key=value format." )
26+ args = parser .parse_args ()
27+
28+ if args .run_number < 0 :
29+ quit ("ERROR: The provided run number must be a non-negative integer." )
30+ if args .headers :
31+ args .headers = dict (header .split ("=" , 1 ) for header in args .headers )
2332
2433 data_arr = np .load ("../../data/data_unfiltered.npy" )
2534 run_arr = np .load ("../../data/runs_unfiltered.npy" )
2635 unique_runs = np .unique (run_arr )
2736
28- if run_number not in unique_runs :
37+ if args . run_number not in unique_runs :
2938 quit (f"ERROR: The specified run number is not present in the sample data. Please, choose one of the following: { unique_runs .tolist ()} " )
3039
3140 # Collect the data to send to the model
32- target_data = data_arr [run_arr == run_number ]
41+ target_data = data_arr [run_arr == args . run_number ]
3342
3443 # Do the inference
35- predictions = inference_over_http (target_data )
44+ predictions = inference_over_http (target_data , args . url , args . headers )
3645
3746 # Parse the output predictions
3847 reconstructed_data = []
@@ -44,7 +53,7 @@ def inference_over_http(data: np.array):
4453 # Plot
4554 avg_mse = np .array (avg_mse )
4655 plt .figure (figsize = (15 , 5 ))
47- plt .plot (range (avg_mse .shape [0 ]), avg_mse , label = f"MSE { run_number } " )
56+ plt .plot (range (avg_mse .shape [0 ]), avg_mse , label = f"MSE { args . run_number } " )
4857 plt .legend ()
4958 plt .show ()
5059 plt .close ()
0 commit comments