3838 of the generation loop at the relevant slot.
3939 - Regardless, it performs a step.
4040 - It takes the sampled tokens, and places them on a 'detokenizing_queue'.
41- 7. Within the detokenizing thread (Prefill and Generate separately) :
41+ 7. Within the detokenizing thread:
4242 - Tokens are detokenized for every 'slot' in a given set of sampled tokens.
4343 - When an end condition is met, the 'slot' integer is returned to the
4444 respective generation queue.
@@ -220,8 +220,7 @@ class Driver:
220220 # Stage 4
221221 # This can be a list because we can pass it as an arg to generate and
222222 # detokenize threads. It is a list of tokens to be detokenized.
223- _prefill_detokenize_backlogs : list [queue .Queue [engine_api .ResultTokens ]] = []
224- _generate_detokenize_backlogs : list [queue .Queue [engine_api .ResultTokens ]] = []
223+ _detokenize_backlogs : list [queue .Queue [engine_api .ResultTokens ]] = []
225224 _generate_slots : list [queue .Queue [int ]] = []
226225 _active_requests : list [queue .Queue [tuple [int , ActiveRequest ]]] = []
227226
@@ -281,11 +280,11 @@ def __init__(
281280 # one of the generate backlogs.
282281 # Interleaved Mode: Max size is 1 to increase the HBM utilization
283282 # during generate.
284- # Disaggregated Mode: Max size is 16 to allow for total 16 prefills to
285- # be enqueued or enqueued while 1 is being transferred.
283+ # Disaggregated Mode: Max size is 4 to allow 2 prefills to be enqueued
284+ # while 1 transfer is queued and 1 is being transferred.
286285 # TODO: Make queue size configurable.
287286 self ._transfer_backlogs = [
288- queue .Queue (1 if self ._interleaved_mode else 16 )
287+ queue .Queue (1 if self ._interleaved_mode else 4 )
289288 for i in range (len (self ._prefill_engines ))
290289 ]
291290 if self ._metrics_collector :
@@ -313,11 +312,10 @@ def __init__(
313312 functools .partial (float , backlog .qsize ())
314313 )
315314 # Stage 4
316- # After prefill and generation, ActiveRequests are placed on the
317- # detokenization backlog for tokens to be sent into each ActiveRequest's
318- # return channel.
319- # We have one of these per prefill / generate engine to simplify
320- # the logic keeping track of which generation engine to replace slots on.
315+ # After generation, ActiveRequests are placed on the detokenization backlog
316+ # for tokens to be sent into each ActiveRequest's return channel.
317+ # We have one of these per generate engine to simplify the logic keeping
318+ # track of which generation engine to replace slots on.
321319 # This is a queue of either - tuple[int, ActiveRequest] which represents our
322320 # active requests, or tuple[int, sample_tokens]. We combine these into one
323321 # queue because it allows us to be somewhat clever with how we do
@@ -332,16 +330,7 @@ def __init__(
332330 # the possibility of race conditions where a slot is made live before the
333331 # tokens are ready and it receives tokens from a different sequence,
334332 # or tokens detokenized before the relevant slot is live.
335-
336- self ._prefill_detokenize_backlogs = [
337- # No need to set maxsize, as transfer queue can
338- # provide the backpressure to the prefill workload
339- # (to avoid the overwhelming prefill).
340- queue .Queue ()
341- for _ in self ._prefill_engines
342- ]
343-
344- self ._generate_detokenize_backlogs = [
333+ self ._detokenize_backlogs = [
345334 # We don't let detokenization accumulate more than 8 steps to avoid
346335 # synchronization issues.
347336 queue .Queue (8 )
@@ -397,25 +386,13 @@ def __init__(
397386 )
398387 for idx in range (len (self ._generate_engines ))
399388 ]
400- self .prefill_detokenize_threads = [
401- JetThread (
402- target = functools .partial (
403- self ._detokenize_thread ,
404- is_prefill = True ,
405- idx = idx ,
406- ),
407- name = f"prefill_detokenize-{ idx } " ,
408- )
409- for idx in range (len (self ._prefill_engines ))
410- ]
411- self .generate_detokenize_threads = [
389+ self .detokenize_threads = [
412390 JetThread (
413391 target = functools .partial (
414392 self ._detokenize_thread ,
415- is_prefill = False ,
416- idx = idx ,
393+ idx ,
417394 ),
418- name = f"generate_detokenize -{ idx } " ,
395+ name = f"detokenize -{ idx } " ,
419396 )
420397 for idx in range (len (self ._generate_engines ))
421398 ]
@@ -424,8 +401,7 @@ def __init__(
424401 self ._prefill_threads ,
425402 self ._transfer_threads ,
426403 self ._generate_threads ,
427- self .prefill_detokenize_threads ,
428- self .generate_detokenize_threads ,
404+ self .detokenize_threads ,
429405 )
430406 )
431407 self .live = True
@@ -444,8 +420,7 @@ def stop(self):
444420 [self ._prefill_backlog ],
445421 self ._transfer_backlogs ,
446422 self ._generate_backlogs .values (),
447- self ._prefill_detokenize_backlogs ,
448- self ._generate_detokenize_backlogs ,
423+ self ._detokenize_backlogs ,
449424 )
450425 )
451426
@@ -561,7 +536,7 @@ def _prefill_thread(self, idx: int):
561536
562537 # put first token to detokenize queue
563538 request .complete = np .zeros ((prefill_engine .samples_per_slot ,), np .bool_ )
564- my_detokenize_backlog = self ._prefill_detokenize_backlogs [idx ]
539+ my_detokenize_backlog = self ._detokenize_backlogs [idx ]
565540 request .metadata .transfer_enqueue_time = time .perf_counter ()
566541 my_detokenize_backlog .put (
567542 (first_token , request , request .metadata .prefill_dequeue_time ),
@@ -657,7 +632,7 @@ def _generate_thread(self, idx: int):
657632 generate_engine = self ._generate_engines [idx ]
658633 my_slots = self ._generate_slots [idx ]
659634 my_generate_backlog = self ._generate_backlogs [idx ]
660- my_detokenize_backlog = self ._generate_detokenize_backlogs [idx ]
635+ my_detokenize_backlog = self ._detokenize_backlogs [idx ]
661636
662637 # Keep track of what step tokens were generated at.
663638 generate_timestep = 0
@@ -787,17 +762,12 @@ def _generate_thread(self, idx: int):
787762 )
788763 time_of_last_generate = time .time ()
789764
790- def _detokenize_thread (self , is_prefill : bool , idx : int ):
765+ def _detokenize_thread (self , idx : int ):
791766 """Detokenize sampled tokens and returns them to the user."""
792767 # One of these per generate engine.
793768 # For all filled my_slots, pop the sampled token onto the relevant
794769 # request's return channel. If it is done, place it back onto free slots.
795-
796- if is_prefill :
797- my_detokenize_backlog = self ._prefill_detokenize_backlogs [idx ]
798- else :
799- my_detokenize_backlog = self ._generate_detokenize_backlogs [idx ]
800-
770+ my_detokenize_backlog = self ._detokenize_backlogs [idx ]
801771 my_generate_engine = self ._generate_engines [idx ]
802772 my_slots = self ._generate_slots [idx ]
803773
0 commit comments