
Commit 800012c

Implement batch speculative decoding support
- Add batch speculative decoding functionality for batch_size > 1
- Update candidate generator to handle batch processing
- Enhance generation utils with batch speculative decoding support
- Add cache utilities for batch speculative decoding
- Update tests for batch speculative decoding
1 parent ff13eb6 commit 800012c

6 files changed: +601 −116 lines


src/transformers/cache_utils.py

Lines changed: 186 additions & 0 deletions
@@ -80,6 +80,75 @@ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
         self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
         self.values = self.values.index_select(0, beam_idx.to(self.values.device))
 
+    def align(
+        self,
+        new_seq_length: int,
+        copy_instructions: list[tuple[int, slice, slice]],
+    ) -> None:
+        """
+        Align this layer's cache based on copy instructions.
+
+        Args:
+            new_seq_length (`int`): The new sequence length for the aligned cache.
+            copy_instructions (`list[tuple[int, slice, slice]]`): List of (batch_idx, src_slice, dst_slice) tuples
+                specifying what to copy from the old cache to the new cache.
+        """
+        if not self.is_initialized:
+            return
+
+        B, H, _, D = self.keys.shape
+        new_keys = self.keys.new_zeros((B, H, new_seq_length, D))
+        new_values = self.values.new_zeros((B, H, new_seq_length, D))
+
+        # Execute the pre-calculated copy instructions
+        for i, src_slice, dst_slice in copy_instructions:
+            new_keys[i, :, dst_slice] = self.keys[i, :, src_slice]
+            new_values[i, :, dst_slice] = self.values[i, :, src_slice]
+
+        self.keys = new_keys
+        self.values = new_values
+
+    def compress_and_repad_cache(self, padding_mask):
+        """Drop the positions flagged in `padding_mask` and re-pack the remaining tokens, left-padded."""
+        # padding_mask: True = pad, False = keep
+        B, H, S, D = self.keys.shape
+
+        # 1. Compute per-sequence lengths and the new maximum length
+        # Invert the mask: True = keep
+        keep_mask = ~padding_mask  # [B, S]
+        lengths = keep_mask.sum(dim=1)  # [B]
+        max_len = lengths.max().item()
+
+        # 2. Allocate the output (pre-filled with padding/zeros),
+        #    directly in the final shape [B, H, max_len, D]
+        out_keys = self.keys.new_zeros((B, H, max_len, D))
+        out_values = self.values.new_zeros((B, H, max_len, D))
+
+        # 3. Build the "destination" mask for left-padding.
+        #    Valid data should sit at the END of the sequence (left padding):
+        #    row i gets (max_len - length_i) pads, then its valid data.
+
+        # shape: [max_len]
+        range_tensor = torch.arange(max_len, device=self.keys.device)
+        # shape: [B, max_len], broadcast comparison
+        # Example: max_len=5, length=3. We want indices 2, 3, 4 to be True:
+        # range (0,1,2,3,4) >= (5-3=2) -> F,F,T,T,T
+        dest_mask = range_tensor >= (max_len - lengths.unsqueeze(1))
+
+        # 4. Perform the copy (the fast part).
+        #    Transpose (B, H, S, D) -> (B, S, H, D) so the mask (B, S) aligns;
+        #    this extracts ONLY the valid tokens into a flat buffer [total_valid, H, D].
+        valid_keys = self.keys.transpose(1, 2)[keep_mask]
+        valid_values = self.values.transpose(1, 2)[keep_mask]
+
+        # Assign into the output using the destination mask.
+        # The output is transposed to (B, max_len, H, D) to align with dest_mask (B, max_len).
+        out_keys.transpose(1, 2)[dest_mask] = valid_keys
+        out_values.transpose(1, 2)[dest_mask] = valid_values
+
+        # 5. Assign back
+        self.keys = out_keys
+        self.values = out_values
+
 
 class DynamicLayer(CacheLayerMixin):
     """
@@ -891,6 +960,94 @@ def __len__(self):
         # forward through all the layers
         return len(self.layers)
 
+    def align(
+        self,
+        new_ids: torch.LongTensor,
+        ids_in_cache: torch.LongTensor,
+        pad_token_id: int,
+        return_new_ids_in_cache: bool = False,
+    ):
+        """
+        Align the cache when input sequences change (e.g., when batching different sequences together).
+
+        Args:
+            new_ids (`torch.LongTensor`): The new input IDs after batching changes.
+            ids_in_cache (`torch.LongTensor`): The input IDs that were used to build the current cache.
+            pad_token_id (`int`): The padding token ID.
+            return_new_ids_in_cache (`bool`, *optional*, defaults to `False`): Whether to return the aligned input IDs.
+
+        Returns:
+            `None` if `return_new_ids_in_cache=False`, otherwise the aligned input IDs tensor.
+        """
+        # 1. Set up metadata (shape: [Batch, Heads, Sequence_Length, Dimension]).
+        #    The first layer is accessed only to get shapes and device.
+        if len(self.layers) == 0 or not self.layers[0].is_initialized:
+            raise ValueError("Cache is not initialized")
+
+        ref_layer = self.layers[0]
+        B, H, S_old, D = ref_layer.keys.shape
+        S_new = new_ids.shape[1] - 1  # the aligned cache covers all but the last token of `new_ids`
+
+        # 2. Pre-calculate "what to copy" for the whole batch ONCE.
+
+        # Find start indices (vectorized)
+        # Note: sum() assumes left-padding only.
+        old_start_indices = (ids_in_cache == pad_token_id).sum(dim=1)
+        new_start_indices = (new_ids == pad_token_id).sum(dim=1)
+
+        # The copy instructions are stored here and applied to all layers later.
+        # Format: list of tuples (batch_idx, source_slice, dest_slice)
+        copy_instructions = []
+
+        # We still loop over the batch (B), but only once, not B * num_layers times
+        for i in range(B):
+            # Identify the content without padding.
+            # Standard Python slicing is used here as it's just index math, very fast.
+            o_start = old_start_indices[i].item()
+            n_start = new_start_indices[i].item()
+
+            # Get the actual token sequences (views, not copies).
+            # The comparison is performed on the ID tensors (int64), which is cheap.
+            trimmed_old = ids_in_cache[i, o_start:]
+            trimmed_new = new_ids[i, n_start:]
+
+            min_len = min(len(trimmed_old), len(trimmed_new))
+
+            # Compare only up to min_len.
+            # Using .ne() (not equal) and finding the first True is faster than element-by-element checks.
+            if min_len == 0:
+                copy_len = 0
+            else:
+                # Find mismatches: (a != b)
+                mismatch = trimmed_old[:min_len].ne(trimmed_new[:min_len])
+                if not mismatch.any():
+                    copy_len = min_len
+                else:
+                    # argmax on a boolean tensor gives the index of the first True
+                    copy_len = mismatch.int().argmax().item()
+
+            if copy_len > 0:
+                # Define the slice objects now so they are not recreated once per layer
+                src_slice = slice(o_start, o_start + copy_len)
+                # Align to the right (left-padding): the copy lands in the last `copy_len` positions
+                dst_slice = slice(-copy_len, None)
+                copy_instructions.append((i, src_slice, dst_slice))
+
+        # 3. Apply the changes to all layers using the per-layer align method
+        for layer in self.layers:
+            layer.align(S_new, copy_instructions)
+
+        if return_new_ids_in_cache:
+            new_input_ids_in_cache = ids_in_cache.new_zeros((B, S_new))
+            # Execute the copy instructions for the input IDs as well
+            for i, src_slice, dst_slice in copy_instructions:
+                new_input_ids_in_cache[i, dst_slice] = ids_in_cache[i, src_slice]
+            return new_input_ids_in_cache
+
+    def compress_and_repad_cache(self, padding_mask):
+        """Drop the positions flagged in `padding_mask` (True = drop) from every layer and re-pad to the left."""
+        for layer in self.layers:
+            layer.compress_and_repad_cache(padding_mask)
+
 
 class DynamicCache(Cache):
     """
@@ -1277,6 +1434,35 @@ def batch_select_indices(self, indices: torch.Tensor):
         self.self_attention_cache.batch_select_indices(indices)
         self.cross_attention_cache.batch_select_indices(indices)
 
+    def align(
+        self,
+        new_ids: torch.LongTensor,
+        ids_in_cache: torch.LongTensor,
+        pad_token_id: int,
+        return_new_ids_in_cache: bool = False,
+    ):
+        """
+        Align the cache when input sequences change (e.g., when batching different sequences together).
+        Only the self-attention cache is re-aligned; the cross-attention cache is left unchanged.
+
+        Args:
+            new_ids (`torch.LongTensor`): The new input IDs after batching changes.
+            ids_in_cache (`torch.LongTensor`): The input IDs that were used to build the current cache.
+            pad_token_id (`int`): The padding token ID.
+            return_new_ids_in_cache (`bool`, *optional*, defaults to `False`): Whether to return the aligned input IDs.
+
+        Returns:
+            `None` if `return_new_ids_in_cache=False`, otherwise the aligned input IDs tensor.
+        """
+        if return_new_ids_in_cache:
+            aligned_ids = self.self_attention_cache.align(new_ids, ids_in_cache, pad_token_id, return_new_ids_in_cache)
+            return aligned_ids
+        else:
+            self.self_attention_cache.align(new_ids, ids_in_cache, pad_token_id, return_new_ids_in_cache)
+
+    def compress_and_repad_cache(self, padding_mask):
+        self.self_attention_cache.compress_and_repad_cache(padding_mask)
+
     def get_max_cache_shape(self) -> int:
         """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
         return self.self_attention_cache.get_max_cache_shape()
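Finally, a similarly hedged sketch of `compress_and_repad_cache`, which drops flagged positions (e.g., rejected draft tokens during batch speculative decoding) and re-packs the remaining entries left-padded. Again, the tensors and mask are made up, and this assumes a build that contains this commit.

    import torch
    from transformers import DynamicCache

    cache = DynamicCache()
    cache.update(torch.randn(2, 2, 6, 4), torch.randn(2, 2, 6, 4), layer_idx=0)

    # True = drop this position, False = keep it.
    # Row 0 keeps 3 positions, row 1 keeps 5, so the repacked cache is left-padded to length 5.
    padding_mask = torch.tensor([[True, True, True, False, False, False],
                                 [True, False, False, False, False, False]])
    cache.compress_and_repad_cache(padding_mask)
    print(cache.layers[0].keys.shape)  # torch.Size([2, 2, 5, 4])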
