Commit c0f8312

Add test for benchmark_serving (#172)

* Add test for benchmark_serving

  Additionally, fix two minor issues:
  1. Sort the import statements in benchmark_serving lexicographically within each group.
  2. Add the missing transformers requirement.

* Exclude benchmark_serving from coverage failure

  Exclude benchmark_serving from the coverage-failure check because its main entry point, tokenizer loading, and random-sampling code cannot realistically be covered by a simple unit test. They can be better tested after further refactoring and after extending the newly added test.

* Fix exclude path

* Exclude eval_accuracy because it was not included before
1 parent 8d3b5e9 commit c0f8312

File tree: 4 files changed (+144, −8 lines)

Makefile

Lines changed: 1 addition & 1 deletion

@@ -51,4 +51,4 @@ unit-tests:
 	coverage run -m unittest -v
 
 check-test-coverage:
-	coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/external_tokenizers/*" --fail-under=96
+	coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/external_tokenizers/*,benchmarks/benchmark_serving.py,benchmarks/eval_accuracy.py" --fail-under=96

benchmarks/benchmark_serving.py

Lines changed: 4 additions & 6 deletions

@@ -64,23 +64,21 @@
 from datetime import datetime
 import gc
 import json
+import os
 import random
 import time
 from typing import Any, AsyncGenerator, Optional
-import os
-
 
+from benchmarks.eval_accuracy import eval_accuracy
+from benchmarks.metrics import CounterMetric, EventMetric
 import grpc
-from benchmarks.metrics import EventMetric, CounterMetric
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
 from jetstream.engine.token_utils import load_vocab
 from jetstream.external_tokenizers.llama3 import llama3_tokenizer
 import numpy as np
-from tqdm.asyncio import tqdm  # pytype: disable=pyi-error
 import pandas
-
-from eval_accuracy import eval_accuracy
+from tqdm.asyncio import tqdm  # pytype: disable=pyi-error
 from transformers import AutoTokenizer
 
 
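The switch from the bare `from eval_accuracy import eval_accuracy` to the package-qualified `from benchmarks.eval_accuracy import eval_accuracy` is what lets the new unit test load the module as part of the `benchmarks` package. A minimal sketch of that assumption, run from the repository root (the `print` line is only illustrative):

# With absolute "benchmarks.*" imports inside benchmark_serving.py, the module
# resolves when imported as a member of the "benchmarks" package, which is how
# the new test loads it. A bare "from eval_accuracy import ..." would only work
# if benchmarks/ itself were on sys.path.
from benchmarks import benchmark_serving

print(benchmark_serving.__name__)  # "benchmarks.benchmark_serving"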

benchmarks/requirements.in

Lines changed: 2 additions & 1 deletion

@@ -1,4 +1,5 @@
 nltk
 evaluate
 rouge-score
-tqdm
+transformers
+tqdm
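The new transformers entry covers the `AutoTokenizer` import in benchmark_serving.py. A minimal sketch of that usage, assuming a Hugging Face tokenizer id such as "gpt2" is downloadable or cached locally (the model id here is only illustrative):

# Load a tokenizer by name and count tokens for a prompt, mirroring how the
# benchmark script tokenizes requests.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
prompt_len = len(tokenizer.encode("The quick brown fox"))
print(prompt_len)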
Lines changed: 137 additions & 0 deletions

@@ -0,0 +1,137 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for benchmarks."""
+
+import asyncio
+import unittest
+from unittest import mock
+from benchmarks import benchmark_serving
+from jetstream.core.proto import jetstream_pb2
+
+
+class TestBenchmarkServing(unittest.IsolatedAsyncioTestCase):
+  """Tests for benchmark_serving.py."""
+
+  async def test_benchmark(self):
+    api_url = "test_url"
+    tokenizer = mock.MagicMock()
+    tokenizer.encode = mock.MagicMock(return_value=[1, 2, 3])
+    tokenizer.decode = mock.MagicMock(return_value="test_decode")
+    input_requests = [
+        benchmark_serving.InputRequest(
+            prompt="test_prompt", prompt_len=3, output_len=5, sample_idx=0
+        ),
+        benchmark_serving.InputRequest(
+            prompt="another_prompt", prompt_len=3, output_len=5, sample_idx=0
+        ),
+    ]
+    request_rate = 0.0
+    prefill_quota = benchmark_serving.AsyncCounter(1)
+    active_req_quota = benchmark_serving.AsyncCounter(10)
+    disable_tqdm = True
+
+    async def mocked_decode_response():
+      """Mocks decode response as an async generator."""
+      responses = [
+          jetstream_pb2.DecodeResponse(
+              stream_content=jetstream_pb2.DecodeResponse.StreamContent(
+                  samples=[
+                      jetstream_pb2.DecodeResponse.StreamContent.Sample(
+                          token_ids=[1]
+                      ),
+                  ]
+              )
+          ),
+      ]
+
+      for response in responses:
+        await asyncio.sleep(10)  # Introduce a 10-second delay
+        yield response
+
+    def mock_orchestrator_factory(*args, **kwargs):
+      """Mocks generation of an orchestrator stub."""
+      del args, kwargs  # Unused.
+      mock_stub = mock.MagicMock()
+      mock_stub.Decode.return_value = mocked_decode_response()
+      return mock_stub
+
+    with mock.patch(
+        "grpc.aio.insecure_channel", new_callable=mock.MagicMock
+    ) as _, mock.patch(
+        "jetstream.core.proto.jetstream_pb2_grpc.OrchestratorStub",
+        new_callable=mock.MagicMock,
+    ) as mock_stub:
+      mock_stub.side_effect = mock_orchestrator_factory
+
+      metrics, outputs = await benchmark_serving.benchmark(
+          api_url,
+          tokenizer,
+          input_requests,
+          request_rate,
+          disable_tqdm,
+          prefill_quota,
+          active_req_quota,
+      )
+
+    self.assertEqual(len(outputs), 2)
+    self.assertEqual(outputs[0].generated_text, "test_decode")
+    self.assertTrue(outputs[0].success)
+    self.assertEqual(metrics["completed"], 2)
+
+  def test_calculate_metrics(self):
+    input_requests = [
+        benchmark_serving.InputRequest(
+            prompt="test_prompt", prompt_len=5, output="test", output_len=4
+        )
+    ]
+    outputs = [
+        benchmark_serving.RequestFuncOutput(
+            input_request=input_requests[0],
+            generated_text="test",
+            generated_token_list=[1, 2, 3, 4],
+            success=True,
+            latency_sec=0.4,
+            ttft_sec=0.1,
+            ttst_sec=0.2,
+            prompt_len=5,
+        )
+    ]
+
+    tokenizer = mock.MagicMock()
+    dur_s = 1.0
+
+    metrics = benchmark_serving.calculate_metrics(
+        input_requests, outputs, dur_s, tokenizer
+    )
+
+    self.assertIsInstance(metrics, benchmark_serving.BenchmarkMetrics)
+    self.assertEqual(metrics.completed, 1)
+    self.assertEqual(metrics.total_input, 5)
+    self.assertEqual(metrics.total_output, 4)
+
+  def test_str2bool(self):
+    self.assertTrue(benchmark_serving.str2bool("true"))
+    self.assertTrue(benchmark_serving.str2bool("1"))
+    self.assertTrue(benchmark_serving.str2bool("yes"))
+    self.assertFalse(benchmark_serving.str2bool("false"))
+    self.assertFalse(benchmark_serving.str2bool("0"))
+    self.assertFalse(benchmark_serving.str2bool("no"))
+
+    with self.assertRaises(ValueError):
+      benchmark_serving.str2bool("test")
+
+
+if __name__ == "__main__":
+  unittest.main()
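The core trick in the test above is standing in for the streaming gRPC Decode call with an async generator returned from a MagicMock stub. A standalone sketch of that pattern under the same assumptions (names such as fake_stream and FakeStreamTest are illustrative, not part of the repo):

import asyncio
import unittest
from unittest import mock


async def fake_stream():
  """Yields chunks like a streaming RPC would, giving up control in between."""
  for token_ids in ([1], [2, 3]):
    await asyncio.sleep(0)
    yield token_ids


class FakeStreamTest(unittest.IsolatedAsyncioTestCase):

  async def test_consumes_stream(self):
    # A MagicMock stub whose Decode() call returns the async generator.
    stub = mock.MagicMock()
    stub.Decode.return_value = fake_stream()

    received = [chunk async for chunk in stub.Decode()]
    self.assertEqual(received, [[1], [2, 3]])


if __name__ == "__main__":
  unittest.main()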
