 import asyncio
 from dataclasses import dataclass, field
 from datetime import datetime
+import gc
 import json
 import random
 import time
@@ -107,6 +108,40 @@ def str2bool(v: str) -> bool:
     raise ValueError(f"Invalid value '{v}'!")


+class AsyncCounter:
+  """A counter class for counting and quota management with asyncio;
+  it is not thread safe. It is safe with asyncio because value changes
+  happen outside of await statements.
+  """
+
+  def __init__(self, init_value: int, block_on_zero_seconds=0.002):
+    """
+    Args:
+      init_value: Initial value for the counter.
+      block_on_zero_seconds: If greater than 0, the counter will spin when
+        the value hits 0, so it can be used for quota management.
+    """
+    self._init_value = init_value
+    self._value = init_value
+    self._block_on_zero_seconds = block_on_zero_seconds
+
+  async def inc(self):
+    self._value += 1
+
+  async def dec(self):
+    while True:
+      if self._value > 0 or self._block_on_zero_seconds <= 0.0:
+        self._value -= 1
+        return
+      await asyncio.sleep(self._block_on_zero_seconds)
+
+  def value(self):
+    return self._value
+
+  def delta(self):
+    return self._init_value - self._value
+
+
 @dataclass
 class BenchmarkMetrics:
   """Data class to store benchmark metrics."""
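For readers unfamiliar with this pattern: with `block_on_zero_seconds > 0`, `AsyncCounter.dec()` spins cooperatively (via `asyncio.sleep`) until a unit of quota is free, and `inc()` hands it back, so the counter behaves like a lightweight semaphore for asyncio code. A minimal usage sketch, not part of the patch, assuming the `AsyncCounter` class added above is importable:

```python
import asyncio


async def worker(name: str, quota: AsyncCounter) -> None:
  await quota.dec()            # wait for a free slot (spins while value is 0)
  try:
    await asyncio.sleep(0.1)   # stand-in for real work
    print(f"{name} finished; remaining quota: {quota.value()}")
  finally:
    await quota.inc()          # hand the slot back


async def demo() -> None:
  quota = AsyncCounter(init_value=2)  # at most two workers in flight
  await asyncio.gather(*(worker(f"w{i}", quota) for i in range(5)))


if __name__ == "__main__":
  asyncio.run(demo())
```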
@@ -378,13 +413,15 @@ def calculate_metrics(
   completed = 0
   per_token_latencies = []
   ttfts = []
+  output_sizes = []
   for i in range(len(outputs)):
     if outputs[i].success:
       output_len = len(
           outputs[i].generated_token_list
           if tokenizer != "test"
           else ["Ċ", "Ō", "Ɵ"]
       )
+      output_sizes.append(output_len)
       total_output += output_len
       total_input += input_requests[i].prompt_len
       if output_len == 0:
@@ -397,6 +434,10 @@ def calculate_metrics(
       ttfts.append(outputs[i].ttft)
       completed += 1

+  print("Mean output size:", float(np.mean(output_sizes)))
+  print("Median output size:", float(np.median(output_sizes)))
+  print("P99 output size:", float(np.percentile(output_sizes, 99)))
+
   metrics = BenchmarkMetrics(
       completed=completed,
       total_input=total_input,
@@ -416,21 +457,32 @@ def calculate_metrics(


 async def grpc_async_request(
-    api_url: str, request: Any
+    api_url: str,
+    request: Any,
+    prefill_quota: AsyncCounter,
+    active_req_quota: AsyncCounter,
 ) -> tuple[list[str], float, float]:
   """Send grpc synchronous request since the current grpc server is sync."""
   options = [("grpc.keepalive_timeout_ms", 10000)]
   async with grpc.aio.insecure_channel(api_url, options=options) as channel:
     stub = jetstream_pb2_grpc.OrchestratorStub(channel)
-    print("Making request")
     ttft = 0
     token_list = []
     request_start_time = time.perf_counter()
     response = stub.Decode(request)
     async for resp in response:
       if ttft == 0:
+        await prefill_quota.inc()
+
         ttft = time.perf_counter() - request_start_time
+        if ttft > 2.0:
+          print(
+              datetime.now(),
+              f"slow TTFT {ttft:.2f}",
+              prefill_quota.value(),
+          )
       token_list.extend(resp.stream_content.samples[0].token_ids)
+    await active_req_quota.inc()
     latency = time.perf_counter() - request_start_time
     return token_list, ttft, latency

@@ -439,22 +491,28 @@ async def send_request(
     api_url: str,
     tokenizer: Any,
     input_request: InputRequest,
+    prefill_quota: AsyncCounter,
+    active_req_quota: AsyncCounter,
     pbar: tqdm,
 ) -> RequestFuncOutput:
   """Send the request to JetStream server."""
+
   # Tokenization on client side following MLPerf standard.
   token_ids = tokenizer.encode(input_request.prompt)
   request = jetstream_pb2.DecodeRequest(
       token_content=jetstream_pb2.DecodeRequest.TokenContent(
           token_ids=token_ids
       ),
       max_tokens=input_request.output_len,
+      metadata=jetstream_pb2.DecodeRequest.Metadata(
+          start_time=time.perf_counter()
+      ),
   )
   output = RequestFuncOutput()
   output.input_request = input_request
   output.prompt_len = input_request.prompt_len
   generated_token_list, ttft, latency = await grpc_async_request(
-      api_url, request
+      api_url, request, prefill_quota, active_req_quota
   )
   output.ttft = ttft
   output.latency = latency
@@ -463,6 +521,12 @@ async def send_request(
   output.generated_text = tokenizer.decode(generated_token_list)
   output.success = True
   if pbar:
+    pbar.postfix = (
+        f"#reqs: {active_req_quota.delta()}/"
+        f"{active_req_quota.value()}; "
+        f"#prefill: {prefill_quota.delta()}/"
+        f"{prefill_quota.value()}"
+    )
     pbar.update(1)
   return output

@@ -473,6 +537,8 @@ async def benchmark(
     input_requests: list[InputRequest],
     request_rate: float,
     disable_tqdm: bool,
+    prefill_quota: AsyncCounter,
+    active_req_quota: AsyncCounter,
 ):
   """Benchmark the online serving performance."""
   pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -482,12 +548,17 @@ async def benchmark(
   benchmark_start_time = time.perf_counter()
   tasks = []
   async for request in get_request(input_requests, request_rate):
+    await prefill_quota.dec()
+    await active_req_quota.dec()
+
     tasks.append(
         asyncio.create_task(
             send_request(
                 api_url=api_url,
                 tokenizer=tokenizer,
                 input_request=request,
+                prefill_quota=prefill_quota,
+                active_req_quota=active_req_quota,
                 pbar=pbar,
             )
         )
     )
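The two counters are acquired and released in different places: `benchmark()` takes one unit from each before dispatching a request, `grpc_async_request()` returns the prefill unit as soon as the first token arrives, and returns the active-request unit once the stream ends. Concurrent prefills and concurrent in-flight requests are therefore bounded independently. A condensed, self-contained sketch of the same flow with a fake token stream (the names `fake_stream`, `handle`, and `dispatch` are illustrative, not from the patch):

```python
import asyncio


async def fake_stream(n_tokens: int):
  for t in range(n_tokens):
    await asyncio.sleep(0.01)  # stand-in for server-side decode time
    yield t


async def handle(prefill_quota: AsyncCounter,
                 active_req_quota: AsyncCounter) -> None:
  first = True
  async for _ in fake_stream(5):
    if first:
      await prefill_quota.inc()   # first token: prefill slot is free again
      first = False
  await active_req_quota.inc()    # stream finished: request slot is free again


async def dispatch(n: int) -> None:
  prefill_quota = AsyncCounter(init_value=3)
  active_req_quota = AsyncCounter(init_value=450)
  tasks = []
  for _ in range(n):
    await prefill_quota.dec()       # block until a prefill slot is available
    await active_req_quota.dec()    # block until a request slot is available
    tasks.append(asyncio.create_task(handle(prefill_quota, active_req_quota)))
  await asyncio.gather(*tasks)
```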
@@ -579,6 +650,9 @@ def main(args: argparse.Namespace):
   tokenizer_id = args.tokenizer
   use_hf_tokenizer = args.use_hf_tokenizer

+  prefill_quota = AsyncCounter(init_value=3)
+  active_req_quota = AsyncCounter(init_value=450)
+
   api_url = f"{args.server}:{args.port}"

   tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
@@ -621,6 +695,8 @@ def main(args: argparse.Namespace):
             input_requests=warmup_requests,
             request_rate=args.request_rate,
             disable_tqdm=args.disable_tqdm,
+            prefill_quota=prefill_quota,
+            active_req_quota=active_req_quota,
         )
     )
     print(f"{args.warmup_mode} warmup completed.")
@@ -636,6 +712,8 @@ def main(args: argparse.Namespace):
           input_requests=input_requests,
           request_rate=args.request_rate,
           disable_tqdm=args.disable_tqdm,
+          prefill_quota=prefill_quota,
+          active_req_quota=active_req_quota,
       )
   )

@@ -836,4 +914,5 @@ def main(args: argparse.Namespace):
   )

   parsed_args = parser.parse_args()
+  gc.disable()
   main(parsed_args)
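The final hunk disables CPython's garbage collector before `main()` runs, presumably so collector pauses do not leak into the TTFT and latency measurements; the process exits afterwards, so nothing re-enables it. A small stand-alone sketch of the same idea, not from the patch, that restores collection once the timed section is over:

```python
import gc
import time


def timed_allocation_loop() -> float:
  """Time an allocation-heavy loop; GC pauses would otherwise inflate this."""
  start = time.perf_counter()
  chunks = [[i] * 16 for i in range(200_000)]  # many short-lived objects
  del chunks
  return time.perf_counter() - start


gc.disable()      # keep collector pauses out of the measurement
try:
  elapsed = timed_allocation_loop()
finally:
  gc.enable()     # restore normal collection afterwards
  gc.collect()    # reclaim anything that piled up while disabled

print(f"elapsed: {elapsed:.4f}s")
```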