
Commit df6ff98

refactor: update xgboost test prediction scripts to accept optional url and headers arguments to be able to test the model deployed in kserve
Parent: bc911d6


2 files changed (+33, -24 lines)


deployment/xgboost_regressor/README.md

Lines changed: 12 additions & 1 deletion
@@ -79,7 +79,18 @@ curl http://localhost:8080/v2/models/my_model
 Run the test predictions script to test if your model is working correctly with MLServer:
 
 ```bash
-python test_predictions.py 360950
+python test_predictions.py -r 360950
+```
+
+## Test predictions on a different server
+
+If using `minikube` to deploy the xgboost model with MLServer behind KServe, you can use the same script to test predictions:
+
+```bash
+python test_predictions.py \
+    -r 360950 \
+    -u "http://$(minikube ip):$(kubectl get svc istio-ingressgateway --namespace istio-system -o jsonpath='{.spec.ports[?(@.name=="http2")].nodePort}')/v2/models/my_model/infer" \
+    -H "Host=$(kubectl get inferenceservice xgboost-example --namespace default -o jsonpath='{.status.url}' | cut -d "/" -f 3)"
 ```
 
 ## Other material
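The `-u`/`-H` flags above work together: the request goes to the minikube node port of the Istio ingress gateway, and the `Host` header tells the gateway which KServe InferenceService should receive it. The sketch below shows a raw V2 inference request along the same lines as the one the script sends; the ingress address, service hostname, input name, shape, and data are placeholder assumptions, not values taken from this commit.

```python
import requests

# Placeholder values: in practice these come from the `minikube ip` and `kubectl` commands shown above.
ingress_url = "http://192.168.49.2:31080/v2/models/my_model/infer"
service_host = "xgboost-example.default.example.com"

# Placeholder V2 inference payload; the real input name, shape, and data depend on the model.
body = {
    "inputs": [
        {
            "name": "input-0",
            "shape": [1, 3],
            "datatype": "FP32",
            "data": [0.1, 0.2, 0.3],
        }
    ]
}

# The explicit Host header lets the Istio ingress gateway route the request to the right InferenceService.
response = requests.post(
    ingress_url,
    headers={"Host": service_host, "Content-Type": "application/json"},
    json=body,
)
response.raise_for_status()
print(response.json())
```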
deployment/xgboost_regressor/test_predictions.py

Lines changed: 21 additions & 23 deletions
@@ -1,23 +1,17 @@
-import sys
+import argparse
 
 import numpy as np
 import requests
 import matplotlib.pyplot as plt
 
 
-# This is only useful if testing the v1 endpoint
-def inference_over_http_v1(data: np.array):
-    url = "http://localhost:8080/v1/models/my_model:predict"
-    headers = {"Content-Type": "application/json"}
-    body = {"instances": data.tolist()}
-    response = requests.post(url, headers=headers, json=body)
-    response.raise_for_status()
-    return response.json()
+def inference_over_http(data: np.array, url: str | None = None, headers: dict | None = None):
+    if url is None:
+        url = "http://localhost:8080/v2/models/my_model/infer"
+    if headers is None:
+        headers = {}
 
-
-def inference_over_http(data: np.array):
-    url = "http://localhost:8080/v2/models/my_model/infer"
-    headers = {"Content-Type": "application/json"}
+    headers["Content-Type"] = "application/json"
     body = {
         "inputs": [
             {
@@ -34,25 +28,29 @@ def inference_over_http(data: np.array):
 
 
 if __name__ == "__main__":
-    try:
-        run_number = int(sys.argv[1])
-    except IndexError:
-        quit("ERROR: Specify a run number while calling this script.")
-    except ValueError:
-        quit("ERROR: The provided run number is invalid.")
+    parser = argparse.ArgumentParser(description="Process run number, optional URL, and optional headers.")
+    parser.add_argument("-r", "--run_number", type=int, required=True, help="The run number (required integer argument).")
+    parser.add_argument("-u", "--url", type=str, help="Optional URL.")
+    parser.add_argument("-H", "--headers", type=str, nargs='*', help="Optional headers in key=value format.")
+    args = parser.parse_args()
+
+    if args.run_number < 0:
+        quit("ERROR: The provided run number must be a non-negative integer.")
+    if args.headers:
+        args.headers = dict(header.split("=", 1) for header in args.headers)
 
     data_arr = np.load("../../data/data_unfiltered.npy")
     run_arr = np.load("../../data/runs_unfiltered.npy")
     unique_runs = np.unique(run_arr)
 
-    if run_number not in unique_runs:
+    if args.run_number not in unique_runs:
         quit(f"ERROR: The specified run number is not present in the sample data. Please, choose one of the following: {unique_runs.tolist()}")
 
     # Collect the data to send to the model
-    target_data = data_arr[run_arr == run_number]
+    target_data = data_arr[run_arr == args.run_number]
 
     # Do the inference
-    predictions = inference_over_http(target_data)
+    predictions = inference_over_http(target_data, args.url, args.headers)
 
     # Parse the output predictions
     outputs = predictions["outputs"]
@@ -62,7 +60,7 @@ def inference_over_http(data: np.array):
     # Plot
     avg_mse = np.array(avg_mse)
     plt.figure(figsize=(15, 5))
-    plt.plot(range(avg_mse.shape[0]), avg_mse, label=f"MSE {run_number}")
+    plt.plot(range(avg_mse.shape[0]), avg_mse, label=f"MSE {args.run_number}")
     plt.legend()
     plt.show()
     plt.close()
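As a quick illustration of the new `-H` handling in the script above, here is how `key=value` tokens become the headers dict passed to `inference_over_http`; the header values are made up for the example.

```python
# Made-up header tokens, as they would arrive from e.g. -H "Host=..." "X-Debug=1".
raw_headers = ["Host=xgboost-example.default.example.com", "X-Debug=1"]

# Same expression as in the script: split each token on the first "=" only,
# so header values that themselves contain "=" stay intact.
headers = dict(header.split("=", 1) for header in raw_headers)

print(headers)
# {'Host': 'xgboost-example.default.example.com', 'X-Debug': '1'}
```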
