
Commit 8bc6d88

A few tweaks to the JetStream code for better observability and throughput. (#158)
+ Added custom GC config on the server side; by default Python does too much GC as we allocate a lot of objects.
+ Tweaked the log level in the orchestrator to WARNING so important messages don't hide in server logs.
+ Added slow-TTFT detection and text logging on both the server and client side (benchmark_serving as the client).
+ Fixed timestamp recording on the server side.
+ Added prefill-based throttling on the client side.
+ Added concurrent active-request throttling on the client side.
1 parent 9ca4421 commit 8bc6d88

File tree

4 files changed: +134 −15 lines

benchmarks/benchmark_serving.py

Lines changed: 82 additions & 3 deletions
@@ -62,6 +62,7 @@
 import asyncio
 from dataclasses import dataclass, field
 from datetime import datetime
+import gc
 import json
 import random
 import time
@@ -107,6 +108,40 @@ def str2bool(v: str) -> bool:
   raise ValueError(f"Invalid value '{v}'!")


+class AsyncCounter:
+  """A counter class for counting and quota management with asyncio;
+  not thread-safe. It's safe with asyncio as value changes are done
+  outside of await statements.
+  """
+
+  def __init__(self, init_value: int, block_on_zero_seconds=0.002):
+    """
+    Args:
+      init_value: Initial value for the counter.
+      block_on_zero_seconds: If greater than 0, the counter will spin when
+        the value hits 0, hence it can be used for quota management.
+    """
+    self._init_value = init_value
+    self._value = init_value
+    self._block_on_zero_seconds = block_on_zero_seconds
+
+  async def inc(self):
+    self._value += 1
+
+  async def dec(self):
+    while True:
+      if self._value > 0 or self._block_on_zero_seconds <= 0.0:
+        self._value -= 1
+        return
+      await asyncio.sleep(self._block_on_zero_seconds)
+
+  def value(self):
+    return self._value
+
+  def delta(self):
+    return self._init_value - self._value
+
+
 @dataclass
 class BenchmarkMetrics:
   """Data class to store benchmark metrics."""
@@ -378,13 +413,15 @@ def calculate_metrics(
   completed = 0
   per_token_latencies = []
   ttfts = []
+  output_sizes = []
   for i in range(len(outputs)):
     if outputs[i].success:
       output_len = len(
           outputs[i].generated_token_list
           if tokenizer != "test"
           else ["Ċ", "Ō", "Ɵ"]
       )
+      output_sizes.append(output_len)
       total_output += output_len
       total_input += input_requests[i].prompt_len
       if output_len == 0:
@@ -397,6 +434,10 @@ def calculate_metrics(
       ttfts.append(outputs[i].ttft)
       completed += 1

+  print("Mean output size:", float(np.mean(output_sizes)))
+  print("Median output size:", float(np.median(output_sizes)))
+  print("P99 output size:", float(np.percentile(output_sizes, 99)))
+
   metrics = BenchmarkMetrics(
       completed=completed,
       total_input=total_input,
@@ -416,21 +457,32 @@


 async def grpc_async_request(
-    api_url: str, request: Any
+    api_url: str,
+    request: Any,
+    prefill_quota: AsyncCounter,
+    active_req_quota: AsyncCounter,
 ) -> tuple[list[str], float, float]:
   """Send grpc synchronous request since the current grpc server is sync."""
   options = [("grpc.keepalive_timeout_ms", 10000)]
   async with grpc.aio.insecure_channel(api_url, options=options) as channel:
     stub = jetstream_pb2_grpc.OrchestratorStub(channel)
-    print("Making request")
     ttft = 0
     token_list = []
     request_start_time = time.perf_counter()
     response = stub.Decode(request)
     async for resp in response:
       if ttft == 0:
+        await prefill_quota.inc()
+
         ttft = time.perf_counter() - request_start_time
+        if ttft > 2.0:
+          print(
+              datetime.now(),
+              f"slow TTFT {ttft:.2f}",
+              prefill_quota.value(),
+          )
       token_list.extend(resp.stream_content.samples[0].token_ids)
+    await active_req_quota.inc()
     latency = time.perf_counter() - request_start_time
     return token_list, ttft, latency

@@ -439,22 +491,28 @@ async def send_request(
     api_url: str,
     tokenizer: Any,
     input_request: InputRequest,
+    prefill_quota: AsyncCounter,
+    active_req_quota: AsyncCounter,
     pbar: tqdm,
 ) -> RequestFuncOutput:
   """Send the request to JetStream server."""
+
   # Tokenization on client side following MLPerf standard.
   token_ids = tokenizer.encode(input_request.prompt)
   request = jetstream_pb2.DecodeRequest(
       token_content=jetstream_pb2.DecodeRequest.TokenContent(
           token_ids=token_ids
       ),
       max_tokens=input_request.output_len,
+      metadata=jetstream_pb2.DecodeRequest.Metadata(
+          start_time=time.perf_counter()
+      ),
   )
   output = RequestFuncOutput()
   output.input_request = input_request
   output.prompt_len = input_request.prompt_len
   generated_token_list, ttft, latency = await grpc_async_request(
-      api_url, request
+      api_url, request, prefill_quota, active_req_quota
   )
   output.ttft = ttft
   output.latency = latency
@@ -463,6 +521,12 @@ async def send_request(
   output.generated_text = tokenizer.decode(generated_token_list)
   output.success = True
   if pbar:
+    pbar.postfix = (
+        f"#reqs: {active_req_quota.delta()}/"
+        f"{active_req_quota.value()}; "
+        f"#prefill: {prefill_quota.delta()}/"
+        f"{prefill_quota.value()}"
+    )
     pbar.update(1)
   return output

@@ -473,6 +537,8 @@ async def benchmark(
     input_requests: list[InputRequest],
     request_rate: float,
     disable_tqdm: bool,
+    prefill_quota: AsyncCounter,
+    active_req_quota: AsyncCounter,
 ):
   """Benchmark the online serving performance."""
   pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -482,12 +548,17 @@
   benchmark_start_time = time.perf_counter()
   tasks = []
   async for request in get_request(input_requests, request_rate):
+    await prefill_quota.dec()
+    await active_req_quota.dec()
+
     tasks.append(
         asyncio.create_task(
            send_request(
                api_url=api_url,
                tokenizer=tokenizer,
                input_request=request,
+                prefill_quota=prefill_quota,
+                active_req_quota=active_req_quota,
                pbar=pbar,
            )
        )
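Taken together, the two counters implement admission control: each request must acquire both a prefill slot and an active-request slot before it is dispatched; the prefill slot is released as soon as the first token arrives, while the active-request slot is held until the whole stream completes. A condensed, self-contained sketch of that lifecycle (the fake stream and the loop are illustrative; assumes the AsyncCounter class above):

import asyncio

async def fake_stream():
  """Stand-in for the streaming Decode RPC; yields a few tokens."""
  for token_id in range(5):
    await asyncio.sleep(0.01)
    yield token_id

async def throttled_request(prefill_quota, active_req_quota) -> None:
  await prefill_quota.dec()       # Wait for a prefill slot.
  await active_req_quota.dec()    # Wait for an in-flight-request slot.
  first_token = True
  async for _ in fake_stream():
    if first_token:
      await prefill_quota.inc()   # Prefill finished: free that slot early.
      first_token = False
  await active_req_quota.inc()    # Full response received: free the slot.

async def demo() -> None:
  prefill_quota = AsyncCounter(init_value=3)
  active_req_quota = AsyncCounter(init_value=450)
  await asyncio.gather(
      *(throttled_request(prefill_quota, active_req_quota) for _ in range(8))
  )

asyncio.run(demo())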
@@ -579,6 +650,9 @@ def main(args: argparse.Namespace):
   tokenizer_id = args.tokenizer
   use_hf_tokenizer = args.use_hf_tokenizer

+  prefill_quota = AsyncCounter(init_value=3)
+  active_req_quota = AsyncCounter(init_value=450)
+
   api_url = f"{args.server}:{args.port}"

   tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
@@ -621,6 +695,8 @@ def main(args: argparse.Namespace):
           input_requests=warmup_requests,
           request_rate=args.request_rate,
           disable_tqdm=args.disable_tqdm,
+          prefill_quota=prefill_quota,
+          active_req_quota=active_req_quota,
       )
   )
   print(f"{args.warmup_mode} warmup completed.")
@@ -636,6 +712,8 @@ def main(args: argparse.Namespace):
           input_requests=input_requests,
           request_rate=args.request_rate,
           disable_tqdm=args.disable_tqdm,
+          prefill_quota=prefill_quota,
+          active_req_quota=active_req_quota,
       )
   )
@@ -836,4 +914,5 @@ def main(args: argparse.Namespace):
   )

   parsed_args = parser.parse_args()
+  gc.disable()
   main(parsed_args)
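Disabling the collector outright is a reasonable move for a short-lived benchmark client: no GC pause can pollute the latency measurements, at the cost of cyclic garbage only being reclaimed at process exit. A small illustration of the pattern (the workload is made up):

import gc
import time

gc.disable()  # No automatic collections; latencies stay free of GC pauses.
start = time.perf_counter()
junk = [[i] * 8 for i in range(200_000)]  # Allocation-heavy stand-in workload.
elapsed = time.perf_counter() - start
print(f"ran {elapsed:.3f}s with GC off; gc.isenabled() = {gc.isenabled()}")
gc.enable()   # Re-enable (and optionally gc.collect()) if the process lives on.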

jetstream/core/config_lib.py

Lines changed: 5 additions & 0 deletions
@@ -39,6 +39,11 @@ class ServerConfig:
   generate_engine_create_fns: Tuple[CreateEngineFn, ...] = ()
   interleaved_engine_create_fns: Tuple[CreateEngineFn, ...] = ()
   is_ray_backend: bool = False
+  # Parameters for customized gc config; increasing the numbers here will
+  # potentially increase memory usage.
+  gc_gen0_allocs: int = 60000  # default is 700, too frequent sometimes.
+  gc_gen1_multipler: int = 2  # Make gen1 gc runs less frequent.
+  gc_gen2_multipler: int = 3  # Make gen2 gc runs less frequent.


 @dataclasses.dataclass
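For context: CPython's collector is generational, and gc.get_threshold() returns (700, 10, 10) on a stock build — a gen-0 collection after ~700 net container allocations, a gen-1 collection every 10 gen-0 runs, and a gen-2 collection every 10 gen-1 runs. A sketch of what the config values above translate to (assuming stock defaults):

import gc

allocs, gen1, gen2 = gc.get_threshold()
print((allocs, gen1, gen2))  # (700, 10, 10) on a stock CPython build.

# Applying the ServerConfig defaults above: gen-0 fires ~85x less often,
# and gen-1/gen-2 collections become 2x/3x rarer relative to that.
gc.set_threshold(60000, gen1 * 2, gen2 * 3)
print(gc.get_threshold())    # (60000, 20, 30)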

jetstream/core/orchestrator.py

Lines changed: 33 additions & 11 deletions
@@ -74,6 +74,7 @@
 to debug hangs due to bugs in threads (it is easier to debug with live logs).
 """

+from datetime import datetime
 import dataclasses
 import functools
 import itertools
@@ -98,10 +99,10 @@
 import numpy as np

 root = logging.getLogger()
-root.setLevel(logging.INFO)
+root.setLevel(logging.WARNING)

 handler = logging.StreamHandler(sys.stdout)
-handler.setLevel(logging.INFO)
+handler.setLevel(logging.WARNING)
 formatter = logging.Formatter(
     "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
 )
@@ -113,18 +114,25 @@
 class ActiveRequestMetadata:
   """Inference request metadata."""

-  start_time: Optional[float] = None
+  start_time: float = 0.0

-  prefill_enqueue_time: Optional[float] = None
-  prefill_dequeue_time: Optional[float] = None
+  prefill_enqueue_time: float = 0.0
+  prefill_dequeue_time: float = 0.0

-  transfer_enqueue_time: Optional[float] = None
-  transfer_dequeue_time: Optional[float] = None
+  transfer_enqueue_time: float = 0.0
+  transfer_dequeue_time: float = 0.0

-  generate_enqueue_time: Optional[float] = None
-  generate_dequeue_time: Optional[float] = None
+  generate_enqueue_time: float = 0.0
+  generate_dequeue_time: float = 0.0

-  complete_time: Optional[float] = None
+  complete_time: float = 0.0
+
+  def stats(self) -> str:
+    return (
+        f"{self.prefill_enqueue_time - self.start_time:.2f};"
+        f"{self.prefill_dequeue_time - self.prefill_enqueue_time:.2f};"
+        f"{time.perf_counter() - self.prefill_dequeue_time:.2f}"
+    )


 @dataclasses.dataclass
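The stats() string packs three semicolon-separated durations: time from request creation to prefill enqueue, time spent waiting in the prefill queue, and time elapsed since prefill dequeue. A hypothetical reading (timestamps invented, all on time.perf_counter()'s clock; assumes the dataclass above is importable):

import time

now = time.perf_counter()
meta = ActiveRequestMetadata(
    start_time=now - 2.0,            # Request created 2.0s ago.
    prefill_enqueue_time=now - 1.9,  # Enqueued for prefill 0.1s after that.
    prefill_dequeue_time=now - 0.4,  # Waited 1.5s in the prefill queue.
)
print(meta.stats())  # ~ "0.10;1.50;0.40": enqueue delay; queue wait; since dequeue.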
@@ -245,7 +253,7 @@ def __init__(
     if generate_params is None:
       generate_params = []

-    logging.info(
+    logging.warning(
         "Initialising driver with %d prefill engines and %d generate engines.",
         len(prefill_engines),
         len(generate_engines),
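With the root logger and its handler both raised to WARNING, every logging.info(...) call in the process is silently dropped — which is why the driver-initialisation message above was bumped to logging.warning(...). A minimal demonstration of the effect:

import logging
import sys

root = logging.getLogger()
root.setLevel(logging.WARNING)
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.WARNING)
root.addHandler(handler)

logging.info("silently dropped: below the WARNING threshold")
logging.warning("printed: survives the raised log level")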
@@ -476,6 +484,9 @@ def get_total_concurrent_requests(self) -> int:
     )
     return total_max_concurrent_decodes

+  def prefill_backlog_size(self):
+    return self._prefill_backlog.qsize()
+
   def place_request_on_prefill_queue(self, request: ActiveRequest):
     """Used to place new requests for prefilling and generation."""
     # Don't block so we can fail and shed load when the queue is full.
@@ -980,6 +991,8 @@ async def Decode(  # pylint: disable=invalid-overridden-method
       context: Optional[grpc.aio.ServicerContext] = None,
   ) -> AsyncIterator[jetstream_pb2.DecodeResponse]:
     """Decode."""
+    request_start_time = time.perf_counter()
+    ttft = 0
     if context is None:
       logging.warning(
           "LLM orchestrator is being used in offline test mode, and will not"
@@ -1031,6 +1044,15 @@ async def Decode(  # pylint: disable=invalid-overridden-method
       buffered_response_list = []
       async for response in active_request.return_channel:
         response = cast(list[ReturnSample], response)
+        if ttft == 0:
+          ttft = time.perf_counter() - request_start_time
+          if ttft > 2.0:
+            print(
+                datetime.now(),
+                f"Slow TTFT: {ttft:.2f}s,"
+                f" stats={active_request.metadata.stats()},"
+                f" prefill_qsize={self._driver.prefill_backlog_size()}",
+            )
         if is_client_side_tokenization:
           # If is_client_side_tokenization, the client should request with token
           # ids, and the JetStream server will return token ids as response.

jetstream/core/server_lib.py

Lines changed: 14 additions & 1 deletion
@@ -19,6 +19,7 @@

 import asyncio
 from concurrent import futures
+import gc
 import logging
 import os
 import signal
@@ -218,8 +219,20 @@ def run(
   # to make sure we can fully saturate the model. Set default minimum to 64.
   threads = threads or max(driver.get_total_concurrent_requests(), 64)
   jetstream_server = JetStreamServer(driver, threads, port, credentials)
-  logging.info("Starting server on port %d with %d threads", port, threads)

+  # Tweak gc config.
+  # Force a gen 2 collection here.
+  gc.collect(generation=2)
+  # Freeze objects currently tracked and ignore them in future gc runs.
+  gc.freeze()
+  allocs, gen1, gen2 = gc.get_threshold()
+  allocs = config.gc_gen0_allocs
+  gen1 = gen1 * config.gc_gen1_multipler
+  gen2 = gen2 * config.gc_gen2_multipler
+  gc.set_threshold(allocs, gen1, gen2)
+  print("GC tweaked (allocs, gen1, gen2): ", allocs, gen1, gen2)
+
+  logging.info("Starting server on port %d with %d threads", port, threads)
   jetstream_server.start()

   if metrics_collector:
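gc.freeze() moves every currently tracked object into a permanent generation that the collector never scans again, so the modules, config, and other long-lived objects built during startup stop contributing to collection cost; collecting gen 2 first ensures no garbage gets frozen alongside them. A small sketch of the mechanism:

import gc

gc.collect(generation=2)  # Clear out garbage so it isn't frozen as live.
print("per-generation counts before freeze:", gc.get_count())
gc.freeze()               # Move all tracked objects to the permanent generation.
print("per-generation counts after freeze:", gc.get_count())
print("frozen objects:", gc.get_freeze_count())
# Frozen objects are ignored by future collections; gc.unfreeze() would
# return them to the oldest generation if that were ever needed.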
