Commit d0f6f32

reverted separate prefill and decode detokenize queue (#176)
1 parent bb41033 commit d0f6f32

File tree

1 file changed (+19 -49)

jetstream/core/orchestrator.py

Lines changed: 19 additions & 49 deletions
@@ -38,7 +38,7 @@
    of the generation loop at the relevant slot.
    - Regardless, it performs a step.
    - It takes the sampled tokens, and places them on a 'detokenizing_queue'.
-7. Within the detokenizing thread (Prefill and Generate separately):
+7. Within the detokenizing thread:
    - Tokens are detokenized for every 'slot' in a given set of sampled tokens.
    - When an end condition is met, the 'slot' integer is returned to the
      respective generation queue.
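
For orientation, a minimal sketch of the single detokenize loop that step 7 describes after this revert. The names here (decode_token, free_slots, the tuple layout, the None shutdown sentinel) are illustrative stand-ins, not the orchestrator's actual API:

import queue

def decode_token(token_id: int) -> str:
  # Stand-in detokenizer; the real one maps ids through the model vocabulary.
  return f"<tok{token_id}>"

def detokenize_loop(backlog: queue.Queue, free_slots: queue.Queue) -> None:
  while True:
    item = backlog.get()
    if item is None:  # hypothetical shutdown sentinel
      return
    slot, token_id, done = item
    print(f"slot {slot}: {decode_token(token_id)}")
    if done:
      # End condition met: return the slot to the generate engine.
      free_slots.put(slot)

backlog, free_slots = queue.Queue(), queue.Queue()
backlog.put((0, 17, False))
backlog.put((0, 42, True))
backlog.put(None)
detokenize_loop(backlog, free_slots)
assert free_slots.get() == 0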
@@ -220,8 +220,7 @@ class Driver:
   # Stage 4
   # This can be a list because we can pass it as an arg to generate and
   # detokenize threads. It is a list of tokens to be detokenized.
-  _prefill_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
-  _generate_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
+  _detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
   _generate_slots: list[queue.Queue[int]] = []
   _active_requests: list[queue.Queue[tuple[int, ActiveRequest]]] = []

@@ -281,11 +280,11 @@ def __init__(
     # one of the generate backlogs.
     # Interleaved Mode: Max size is 1 to increase the HBM utilization
     # during generate.
-    # Disaggregated Mode: Max size is 16 to allow for total 16 prefills to
-    # be enqueued or enqueued while 1 is being transferred.
+    # Disaggregated Mode: Max size is 4 to allow for 2 prefills to be enqueued
+    # while 1 transfer is enqueued while 1 is being transferred.
     # TODO: Make queue size configurable.
     self._transfer_backlogs = [
-        queue.Queue(1 if self._interleaved_mode else 16)
+        queue.Queue(1 if self._interleaved_mode else 4)
         for i in range(len(self._prefill_engines))
     ]
     if self._metrics_collector:
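
The queue-size comment above is the whole backpressure story: queue.Queue.put() blocks once maxsize items are waiting, so prefill is throttled to the transfer thread's pace. A runnable sketch under those assumptions, using the sizes from this hunk (1 interleaved, 4 disaggregated):

import queue
import threading
import time

interleaved_mode = False
transfer_backlog = queue.Queue(1 if interleaved_mode else 4)

def prefill_producer() -> None:
  for i in range(8):
    transfer_backlog.put(i)  # blocks once 4 prefills are already queued
    print(f"enqueued prefill {i}")

def transfer_consumer() -> None:
  while True:
    transfer_backlog.get()
    time.sleep(0.1)  # stand-in for the device-to-device transfer
    transfer_backlog.task_done()

threading.Thread(target=transfer_consumer, daemon=True).start()
prefill_producer()
transfer_backlog.join()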
@@ -313,11 +312,10 @@ def __init__(
             functools.partial(float, backlog.qsize())
         )
     # Stage 4
-    # After prefill and generation, ActiveRequests are placed on the
-    # detokenization backlog for tokens to be sent into each ActiveRequest's
-    # return channel.
-    # We have one of these per prefill / generate engine to simplify
-    # the logic keeping track of which generation engine to replace slots on.
+    # After generation, ActiveRequests are placed on the detokenization backlog
+    # for tokens to be sent into each ActiveRequest's return channel.
+    # We have one of these per generate engine to simplify the logic keeping
+    # track of which generation engine to replace slots on.
     # This is a queue of either - tuple[int, ActiveRequest] which represents our
     # active requests, or tuple[int, sample_tokens]. We combine these into one
     # queue because it allows us to be somewhat clever with how we do
@@ -332,16 +330,7 @@ def __init__(
     # the possibility of race conditions where a slot is made live before the
     # tokens are ready and it receives tokens from a different sequence,
     # or tokens detokenized before the relevant slot is live.
-
-    self._prefill_detokenize_backlogs = [
-        # No need to set maxsize, as transfer queue can
-        # provide the backpressure to the prefill workload
-        # (to avoid the overwhelming prefill).
-        queue.Queue()
-        for _ in self._prefill_engines
-    ]
-
-    self._generate_detokenize_backlogs = [
+    self._detokenize_backlogs = [
         # We don't let detokenization accumulate more than 8 steps to avoid
         # synchronization issues.
         queue.Queue(8)
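
The comment in the previous hunk notes this single queue carries two payload shapes: tuple[int, ActiveRequest] for newly transferred requests, and tuple[int, sample_tokens] for generation steps. A sketch of that dispatch with a simplified stand-in ActiveRequest (the real class carries a return channel, metadata, and completion flags):

import queue
from dataclasses import dataclass

@dataclass
class ActiveRequest:
  prompt: str

detokenize_backlog: queue.Queue = queue.Queue(8)  # bounded at 8 steps, as above

def drain_one(q: queue.Queue) -> None:
  slot, payload = q.get()
  if isinstance(payload, ActiveRequest):
    # A transferred request: bind it to `slot` before its tokens arrive.
    print(f"slot {slot} now owned by request {payload.prompt!r}")
  else:
    # A step's sampled tokens for already-live slots.
    print(f"slot {slot} produced tokens {payload}")

detokenize_backlog.put((0, ActiveRequest("hello")))
detokenize_backlog.put((0, [17, 42]))
drain_one(detokenize_backlog)
drain_one(detokenize_backlog)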
@@ -397,25 +386,13 @@ def __init__(
         )
         for idx in range(len(self._generate_engines))
     ]
-    self.prefill_detokenize_threads = [
-        JetThread(
-            target=functools.partial(
-                self._detokenize_thread,
-                is_prefill=True,
-                idx=idx,
-            ),
-            name=f"prefill_detokenize-{idx}",
-        )
-        for idx in range(len(self._prefill_engines))
-    ]
-    self.generate_detokenize_threads = [
+    self.detokenize_threads = [
         JetThread(
             target=functools.partial(
                 self._detokenize_thread,
-                is_prefill=False,
-                idx=idx,
+                idx,
             ),
-            name=f"generate_detokenize-{idx}",
+            name=f"detokenize-{idx}",
         )
         for idx in range(len(self._generate_engines))
     ]
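
The list comprehension above is the standard pattern for per-engine worker threads: functools.partial pre-binds the engine index so each thread services its own queue. A minimal sketch with plain threading.Thread standing in for JetStream's JetThread (a Thread subclass that also brings the process down on an uncaught exception):

import functools
import threading

def detokenize_thread(idx: int) -> None:
  print(f"detokenize thread {idx} running")

detokenize_threads = [
    threading.Thread(
        target=functools.partial(detokenize_thread, idx),
        name=f"detokenize-{idx}",
    )
    for idx in range(2)  # one per generate engine
]
for t in detokenize_threads:
  t.start()
for t in detokenize_threads:
  t.join()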
@@ -424,8 +401,7 @@ def __init__(
             self._prefill_threads,
             self._transfer_threads,
             self._generate_threads,
-            self.prefill_detokenize_threads,
-            self.generate_detokenize_threads,
+            self.detokenize_threads,
         )
     )
     self.live = True
@@ -444,8 +420,7 @@ def stop(self):
             [self._prefill_backlog],
             self._transfer_backlogs,
             self._generate_backlogs.values(),
-            self._prefill_detokenize_backlogs,
-            self._generate_detokenize_backlogs,
+            self._detokenize_backlogs,
         )
     )

@@ -561,7 +536,7 @@ def _prefill_thread(self, idx: int):

       # put first token to detokenize queue
       request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_)
-      my_detokenize_backlog = self._prefill_detokenize_backlogs[idx]
+      my_detokenize_backlog = self._detokenize_backlogs[idx]
       request.metadata.transfer_enqueue_time = time.perf_counter()
       my_detokenize_backlog.put(
           (first_token, request, request.metadata.prefill_dequeue_time),
@@ -657,7 +632,7 @@ def _generate_thread(self, idx: int):
     generate_engine = self._generate_engines[idx]
     my_slots = self._generate_slots[idx]
     my_generate_backlog = self._generate_backlogs[idx]
-    my_detokenize_backlog = self._generate_detokenize_backlogs[idx]
+    my_detokenize_backlog = self._detokenize_backlogs[idx]

     # Keep track of what step tokens were generated at.
     generate_timestep = 0
@@ -787,17 +762,12 @@ def _generate_thread(self, idx: int):
       )
       time_of_last_generate = time.time()

-  def _detokenize_thread(self, is_prefill: bool, idx: int):
+  def _detokenize_thread(self, idx: int):
     """Detokenize sampled tokens and returns them to the user."""
     # One of these per generate engine.
     # For all filled my_slots, pop the sampled token onto the relevant
     # requests return channel. If it done, place it back onto free slots.
-
-    if is_prefill:
-      my_detokenize_backlog = self._prefill_detokenize_backlogs[idx]
-    else:
-      my_detokenize_backlog = self._generate_detokenize_backlogs[idx]
-
+    my_detokenize_backlog = self._detokenize_backlogs[idx]
     my_generate_engine = self._generate_engines[idx]
     my_slots = self._generate_slots[idx]
