New Feature: Enabling Speculative Decoding with Batch Size > 1 (If draft and target model share tokenizer) #42655
What does this PR do?
Many discussions have focused on speculative decoding with `batch_size > 1` (#26875, #29769, #32165, #32189), but the feature was never fully implemented. This PR does that.
Summary of Changes: Enable Batched Speculative Decoding
Until now, invoking `generate()` with an `assistant_model` raised a `ValueError` if `batch_size > 1`. This update modifies the codebase to support batched speculative decoding.
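For context, here is a minimal usage sketch of what this enables. It is not part of the diff; the checkpoints and generation settings are placeholders, chosen only because the two models share a tokenizer.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder target/draft pair; any pair sharing a tokenizer works.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b", padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # decoder-only models often lack a pad token

model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b").to(device)
assistant = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").to(device)

prompts = [
    "The capital of France is",
    "Speculative decoding works by",
    "Write one sentence about tensors:",
]
# Left padding keeps the generated continuation contiguous for decoder-only models.
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)

# Before this PR, passing assistant_model with batch_size > 1 raised a ValueError.
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=64, do_sample=False)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```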
Key Features & Limitations

- `batch_size > 1` is now supported for standard speculative decoding (draft and target models sharing a tokenizer).
- Assisted decoding where the draft and target models do not share a tokenizer remains limited to `batch_size = 1` and will still raise a `ValueError` if used with a batch.

Technical Implementation Details
Supporting batches requires dealing with "ragged tensors", where different sequences in a batch accept a different number of speculative tokens, leading to misalignment in tensor shapes and KV caches. The implementation addresses this via the following mechanisms:
1. Assistant Cache Alignment
A new `align_cache` function was implemented in `candidate_generator.py`. It identifies which parts of the old cache match the new inputs and removes the now-irrelevant keys and values from the cache. This ensures the assistant's cache correctly reflects the accepted tokens.
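As a rough, hypothetical illustration of the idea (not the PR's actual `align_cache` code), a ragged assistant cache can be made rectangular again by cropping every layer to the shortest still-valid prefix:

```python
import torch

def align_cache_sketch(past_key_values, num_accepted):
    """Hypothetical helper: crop a legacy-format cache (a list of (key, value)
    tensors of shape [batch, heads, seq_len, head_dim]) to the shortest
    still-valid prefix across the batch. Positions beyond that prefix are
    simply recomputed on the assistant's next forward pass."""
    keep = int(num_accepted.min())
    return [(k[:, :, :keep, :], v[:, :, :keep, :]) for k, v in past_key_values]

# Batch of 2: the items kept 5 and 3 valid cached positions respectively.
pkv = [(torch.randn(2, 4, 8, 16), torch.randn(2, 4, 8, 16)) for _ in range(2)]
aligned = align_cache_sketch(pkv, torch.tensor([5, 3]))
print(aligned[0][0].shape)  # torch.Size([2, 4, 3, 16])
```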
2. Repadding Inputs and Main Model Cache

3. Vectorized Verification & Sampling
Verification now uses vectorized `cumsum` operations to identify the index of the first mismatch for every item in the batch simultaneously. `_speculative_sampling` has been updated to calculate acceptance probabilities for the whole batch in a single vectorized pass; a simplified sketch of the mismatch search follows.
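A toy example of the `cumsum` trick (hypothetical code, not the PR's exact implementation): count how many leading draft tokens each batch item accepted, without looping over the batch in Python.

```python
import torch

# matches[i, j] is True when the target model agreed with draft token j of item i.
matches = torch.tensor([
    [True,  True,  False, True ],   # item 0: accepted prefix of length 2
    [True,  True,  True,  True ],   # item 1: all 4 draft tokens accepted
    [False, True,  True,  False],   # item 2: first draft token already rejected
])

# A position belongs to the accepted prefix as long as the running count of
# mismatches up to and including it is still zero.
n_matches = ((~matches).int().cumsum(dim=-1) < 1).sum(dim=-1)
print(n_matches)  # tensor([2, 4, 0])
```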
4. Dynamic Heuristics (`update_candidate_strategy`)

The heuristic that adjusts the draft length (`num_assistant_tokens`) now aggregates statistics across the batch (e.g., using the average number of matches) to determine whether the draft model step size should increase or decrease; a toy version is sketched below.
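A batch-averaged toy version of such a heuristic (hypothetical names, not the PR's exact code):

```python
import torch

def update_num_assistant_tokens(num_assistant_tokens, n_matches, num_candidates):
    """Grow the draft length when the batch accepted (on average) everything
    that was proposed; otherwise shrink it, never going below 1."""
    avg_matches = n_matches.float().mean().item()
    if avg_matches >= num_candidates:
        return num_assistant_tokens + 2
    return max(1, num_assistant_tokens - 1)

# 5 tokens were proposed; the batch accepted [5, 5, 4] of them (average 4.67).
print(update_num_assistant_tokens(5, torch.tensor([5, 5, 4]), 5))  # 4
```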
Note on `max_new_tokens`

When using speculative decoding with `batch_size > 1`, the behavior regarding `max_new_tokens` requires nuance. Because the generation loop serves the entire batch, the process continues until one item satisfies the stopping criteria. However, due to the repadding and varying acceptance rates, other items in the batch may effectively stop generating before reaching the global `max_new_tokens` limit.

Batch Size = 3: Assistant vs Standard Decoding
WITHOUT assistant_model: 9.2215 seconds
WITH assistant_model: 7.4744 seconds
I look forward to your feedback and comments so that this can be included in the next release.