Commit 4bf1fb8
fixes
1 parent 76db8da commit 4bf1fb8

5 files changed: +235 additions, -161 deletions

src/transformers/cache_utils.py

Lines changed: 138 additions & 0 deletions
@@ -80,6 +80,34 @@ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
         self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
         self.values = self.values.index_select(0, beam_idx.to(self.values.device))
 
+    def align(
+        self,
+        new_seq_length: int,
+        copy_instructions: list[tuple[int, slice, slice]],
+    ) -> None:
+        """
+        Align this layer's cache based on copy instructions.
+
+        Args:
+            new_seq_length (`int`): The new sequence length for the aligned cache.
+            copy_instructions (`list[tuple[int, slice, slice]]`): List of (batch_idx, src_slice, dst_slice) tuples
+                specifying what to copy from the old cache to the new cache.
+        """
+        if not self.is_initialized:
+            return
+
+        B, H, _, D = self.keys.shape
+        new_keys = self.keys.new_zeros((B, H, new_seq_length, D))
+        new_values = self.values.new_zeros((B, H, new_seq_length, D))
+
+        # Execute the pre-calculated copy instructions
+        for i, src_slice, dst_slice in copy_instructions:
+            new_keys[i, :, dst_slice] = self.keys[i, :, src_slice]
+            new_values[i, :, dst_slice] = self.values[i, :, src_slice]
+
+        self.keys = new_keys
+        self.values = new_values
+
 
 class DynamicLayer(CacheLayerMixin):
     """
@@ -891,6 +919,90 @@ def __len__(self):
         # forward through all the layers
         return len(self.layers)
 
+    def align(
+        self,
+        new_ids: torch.LongTensor,
+        ids_in_cache: torch.LongTensor,
+        pad_token_id: int,
+        return_new_ids_in_cache: bool = False,
+    ):
+        """
+        Align the cache when input sequences change (e.g., when batching different sequences together).
+
+        Args:
+            new_ids (`torch.LongTensor`): The new input IDs after batching changes.
+            ids_in_cache (`torch.LongTensor`): The input IDs that were used to build the current cache.
+            pad_token_id (`int`): The padding token ID.
+            return_new_ids_in_cache (`bool`, *optional*, defaults to `False`): Whether to return the aligned input IDs.
+
+        Returns:
+            `None` if `return_new_ids_in_cache=False`, otherwise the aligned input IDs tensor.
+        """
+        # 1. Set up metadata (shape: [batch, num_heads, seq_length, head_dim])
+        # We access the first layer just to get shapes and device
+        if len(self.layers) == 0 or not self.layers[0].is_initialized:
+            raise ValueError("Cache is not initialized")
+
+        ref_layer = self.layers[0]
+        B, H, S_old, D = ref_layer.keys.shape
+        S_new = new_ids.shape[1] - 1  # the aligned cache holds every new token except the last one
+
+        # 2. Pre-compute what to copy for the whole batch once
+
+        # Find the first non-padding index per row (vectorized)
+        # Note: sum() assumes left-padding only.
+        old_start_indices = (ids_in_cache == pad_token_id).sum(dim=1)
+        new_start_indices = (new_ids == pad_token_id).sum(dim=1)
+
+        # Store the copy instructions here and apply them to every layer later
+        # Format: list of tuples (batch_idx, source_slice, dest_slice)
+        copy_instructions = []
+
+        # Loop over the batch (B) only once, not once per layer
+        for i in range(B):
+            # Identify the content without padding
+            # Plain Python index math here is cheap
+            o_start = old_start_indices[i].item()
+            n_start = new_start_indices[i].item()
+
+            # Get the actual token sequences (views, not copies)
+            # The comparison runs on the int64 ID tensors, which is cheap
+            trimmed_old = ids_in_cache[i, o_start:]
+            trimmed_new = new_ids[i, n_start:]
+
+            min_len = min(len(trimmed_old), len(trimmed_new))
+
+            # Compare only up to min_len
+            # Use .ne() (not equal) and find the first True instead of element-wise checks
+            if min_len == 0:
+                copy_len = 0
+            else:
+                # Find mismatches: (a != b)
+                mismatch = trimmed_old[:min_len].ne(trimmed_new[:min_len])
+                if not mismatch.any():
+                    copy_len = min_len
+                else:
+                    # argmax on the boolean mask gives the index of the first True
+                    copy_len = mismatch.int().argmax().item()
+
+            if copy_len > 0:
+                # Build the slice objects once instead of once per layer
+                src_slice = slice(o_start, o_start + copy_len)
+                # The reusable prefix is right-aligned in the new cache (-copy_len:)
+                dst_slice = slice(-copy_len, None)
+                copy_instructions.append((i, src_slice, dst_slice))
+
+        # 3. Apply the changes to every layer using the per-layer align method
+        for layer in self.layers:
+            layer.align(S_new, copy_instructions)
+
+        if return_new_ids_in_cache:
+            new_input_ids_in_cache = ids_in_cache.new_zeros((B, S_new))
+            # Apply the same copy instructions to the input IDs
+            for i, src_slice, dst_slice in copy_instructions:
+                new_input_ids_in_cache[i, dst_slice] = ids_in_cache[i, src_slice]
+            return new_input_ids_in_cache
+
 
 class DynamicCache(Cache):
     """
@@ -1277,6 +1389,32 @@ def batch_select_indices(self, indices: torch.Tensor):
         self.self_attention_cache.batch_select_indices(indices)
         self.cross_attention_cache.batch_select_indices(indices)
 
+    def align(
+        self,
+        new_ids: torch.LongTensor,
+        ids_in_cache: torch.LongTensor,
+        pad_token_id: int,
+        return_new_ids_in_cache: bool = False,
+    ):
+        """
+        Align the cache when input sequences change (e.g., when batching different sequences together).
+        Only the self-attention cache is realigned; the cross-attention cache follows the encoder sequence and is left unchanged.
+
+        Args:
+            new_ids (`torch.LongTensor`): The new input IDs after batching changes.
+            ids_in_cache (`torch.LongTensor`): The input IDs that were used to build the current cache.
+            pad_token_id (`int`): The padding token ID.
+            return_new_ids_in_cache (`bool`, *optional*, defaults to `False`): Whether to return the aligned input IDs.
+
+        Returns:
+            `None` if `return_new_ids_in_cache=False`, otherwise the aligned input IDs tensor.
+        """
+        if return_new_ids_in_cache:
+            aligned_ids = self.self_attention_cache.align(new_ids, ids_in_cache, pad_token_id, return_new_ids_in_cache)
+            return aligned_ids
+        else:
+            self.self_attention_cache.align(new_ids, ids_in_cache, pad_token_id, return_new_ids_in_cache)
+
     def get_max_cache_shape(self) -> int:
         """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
         return self.self_attention_cache.get_max_cache_shape()
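As a usage sketch of the new API on this branch (hypothetical shapes and token IDs, not a test from the commit): align keeps, per row, only the prefix of cached entries whose IDs still match after re-batching, right-aligned into a freshly allocated buffer.

import torch
from transformers import DynamicCache

pad_token_id = 0
cache = DynamicCache()
# One decoder layer, batch 2, 1 head, 5 cached positions, head dim 4 (hypothetical)
cache.update(torch.randn(2, 1, 5, 4), torch.randn(2, 1, 5, 4), layer_idx=0)

# IDs the cache was built from, and the re-batched IDs (both left-padded)
ids_in_cache = torch.tensor([[0, 11, 12, 13, 14], [21, 22, 23, 24, 25]])
new_ids = torch.tensor([[0, 0, 11, 12, 13, 99], [21, 22, 23, 24, 25, 31]])

aligned_ids = cache.align(new_ids, ids_in_cache, pad_token_id, return_new_ids_in_cache=True)
print(aligned_ids)
# tensor([[ 0,  0, 11, 12, 13],   row 0: [11, 12, 13] survives, 14 diverged and is dropped
#         [21, 22, 23, 24, 25]])  row 1: the full prefix survives
# Uncopied slots are zero-filled by new_zeros (0 happens to be the pad ID here).

Internally this is the .ne() + argmax comparison from the code above: row 0's pad-stripped IDs [11, 12, 13, 14] first disagree with [11, 12, 13, 99] at index 3, so copy_len is 3 and only those three cache positions are copied, right-aligned, into the new buffer of every layer.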

src/transformers/generation/candidate_generator.py

Lines changed: 21 additions & 90 deletions
@@ -41,13 +41,16 @@
 class CandidateGenerator:
     """Abstract base class for all candidate generators that can be applied during assisted generation."""
 
-    def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
+    def get_candidates(
+        self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None
+    ) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
         """
         Fetches the candidates to be tried for the current input.
 
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                 Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
+            assistant_ids_in_cache (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
 
         Return:
             `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
@@ -248,6 +251,7 @@ def update_candidate_strategy(
         """
         # Handle backward compatibility: convert int to tensor
         if isinstance(num_matches, int):
+            assert input_ids.shape[0] == 1, "num_matches should be a tensor of shape (batch_size,) when batch_size > 1"
             num_matches = torch.tensor([num_matches], device=input_ids.device)
 
         batch_size = input_ids.shape[0]
@@ -332,13 +336,12 @@ def _update_past_and_masks(
         """Update past key values and attention masks for subsequent generation rounds."""
         has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
         if has_past_key_values:
-            new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
-            current_cache = self.assistant_kwargs["past_key_values"]
             if input_ids.shape[0] > 1:
-                self.assistant_kwargs["past_key_values"] = align_cache(
-                    current_cache, input_ids, assistant_ids_in_cache, self.generation_config.pad_token_id
+                self.assistant_kwargs["past_key_values"].align(
+                    input_ids, assistant_ids_in_cache, self.generation_config.pad_token_id
                 )
             else:
+                new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
                 self.assistant_kwargs["past_key_values"].crop(new_cache_size - num_added_tokens)
             self.assistant_kwargs = _prepare_attention_mask(
                 self.assistant_kwargs,
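In the single-sequence branch the assistant cache is still just cropped back to the verified prefix; only the batched branch needs the new align. A minimal sketch of what crop does to a populated cache (hypothetical sizes; crop and get_seq_length are pre-existing Cache methods):

import torch
from transformers import DynamicCache

cache = DynamicCache()
cache.update(torch.randn(1, 1, 10, 4), torch.randn(1, 1, 10, 4), layer_idx=0)

cache.crop(7)                  # drop the trailing, unverified draft positions
print(cache.get_seq_length())  # 7

crop can only trim a suffix shared by the whole batch, which is why rows that keep different numbers of tokens have to go through align instead.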
@@ -561,7 +564,8 @@ def get_candidates(
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                 Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
-
+            assistant_ids_in_cache (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the assistant vocabulary that are in the cache.
         Return:
             `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
             assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
@@ -1116,13 +1120,17 @@ def __init__(
         if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
             raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
 
-    def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
+    def get_candidates(
+        self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None
+    ) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
         """
         Fetches the candidates to be tried for the current input.
 
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                 Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
+            assistant_ids_in_cache (`torch.LongTensor`, *optional*):
+                Assistant model input IDs that are already in the cache. Not used by prompt lookup decoding.
 
         Return:
             `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
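Prompt lookup decoding itself is unchanged here: it only accepts the new argument for interface compatibility and ignores it, and it still targets single sequences. A usage sketch of its existing entry point (placeholder checkpoint, not from the commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox jumps over the", return_tensors="pt")
# Prompt lookup drafts candidate tokens from n-grams already present in the prompt
outputs = model.generate(**inputs, prompt_lookup_num_tokens=10, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))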
@@ -1277,12 +1285,16 @@ def __init__(
         self.assistant_early_exit = self.generation_config.assistant_early_exit
         self.generation_config.assistant_early_exit = None
 
-    def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
+    def get_candidates(
+        self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None
+    ) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
         # Temporarily sets the number of hidden layers to the early exit value
         base_model = getattr(self.assistant_model, self.assistant_model.base_model_prefix)
         original_num_hidden_layers = base_model.config.num_hidden_layers
         base_model.config.num_hidden_layers = self.assistant_early_exit
-        candidate_ids, candidate_logits = super().get_candidates(input_ids)
+        candidate_ids, candidate_logits = super().get_candidates(
+            input_ids, assistant_ids_in_cache=assistant_ids_in_cache
+        )
         base_model.config.num_hidden_layers = original_num_hidden_layers
         return candidate_ids, candidate_logits
 
@@ -1345,84 +1357,3 @@ def _prepare_token_type_ids(model_kwargs: dict[str, Any], new_length: int) -> di
         token_type_copies = final_token_type.repeat(1, type_length_diff)
         model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
     return model_kwargs
-
-
-def align_cache(cache, new_ids, assistant_ids_in_cache, pad_token_id, apply_same_transform_on_old_ids: bool = False):
-    # 1. Setup metadata (Shape: [Batch, Heads, Sequence_Length, Dimension])
-    # We access the first layer just to get shapes and device
-
-    ref_layer = cache.layers[0]
-    old_ids = assistant_ids_in_cache
-    B, H, S_old, D = ref_layer.keys.shape
-    S_new = new_ids.shape[1] - 1  # Preserving your original sizing logic
-
-    # 2. Pre-calculate "What to copy" for the whole batch ONCE.
-    # This removes the logic calculation from the inner loop (32x-80x speedup on logic)
-
-    # Find start indices (Vectorized)
-    # Note: sum() assumes left-padding only.
-    old_start_indices = (old_ids == pad_token_id).sum(dim=1)
-    new_start_indices = (new_ids == pad_token_id).sum(dim=1)
-
-    # We will store the copy instructions here to apply to all layers later
-    # Format: List of tuples (batch_idx, source_slice, dest_slice, copy_len)
-    copy_instructions = []
-
-    # We still loop over batch (B), but only once, not B * Layers
-    for i in range(B):
-        # Identify the content without padding
-        # We use standard python slicing here as it's just index math, very fast
-        o_start = old_start_indices[i].item()
-        n_start = new_start_indices[i].item()
-
-        # Get the actual token sequences (views, not copies)
-        # We perform the comparison on the ID tensors (int64), which is cheap
-        trimmed_old = old_ids[i, o_start:]
-        trimmed_new = new_ids[i, n_start:]
-
-        min_len = min(len(trimmed_old), len(trimmed_new))
-
-        # Compare only up to min_len
-        # Using .ne() (not equal) and finding the first true is faster than checks
-        if min_len == 0:
-            copy_len = 0
-        else:
-            # Find mismatch: (a != b)
-            mismatch = trimmed_old[:min_len].ne(trimmed_new[:min_len])
-            if not mismatch.any():
-                copy_len = min_len
-            else:
-                # argmax on boolean gives index of first True
-                copy_len = mismatch.int().argmax().item()
-
-        if copy_len > 0:
-            # Define the slice objects now so we don't recreate them 32 times
-            src_slice = slice(o_start, o_start + copy_len)
-            # You align to the right (-length:)
-            dst_slice = slice(-copy_len, None)
-            copy_instructions.append((i, src_slice, dst_slice))
-
-    # 3. Apply changes to all layers
-    # We allocate new tensors and copy in bulk based on pre-calculated instructions
-    new_input_ids_in_cache = None
-    for layer in cache.layers:
-        # Allocation (This is the heavy GPU/Memory op)
-        new_keys = layer.keys.new_zeros((B, H, S_new, D))
-        new_values = layer.values.new_zeros((B, H, S_new, D))
-        if apply_same_transform_on_old_ids and new_input_ids_in_cache is None:
-            new_input_ids_in_cache = assistant_ids_in_cache.new_zeros((B, S_new))
-        # Execute the pre-calculated copy instructions
-        for i, src_slice, dst_slice in copy_instructions:
-            # Copy Keys/Values
-            new_keys[i, :, dst_slice] = layer.keys[i, :, src_slice]
-            new_values[i, :, dst_slice] = layer.values[i, :, src_slice]
-
-            if apply_same_transform_on_old_ids and new_input_ids_in_cache is not None:
-                new_input_ids_in_cache[i, dst_slice] = assistant_ids_in_cache[i, src_slice]
-        # Update the layer
-        layer.keys = new_keys
-        layer.values = new_values
-
-    if apply_same_transform_on_old_ids:
-        return cache, new_input_ids_in_cache
-    return cache

src/transformers/generation/utils.py

Lines changed: 12 additions & 6 deletions
@@ -3649,8 +3649,16 @@ def _assisted_decoding(
 
         # keep track of which sequences are already finished
         batch_size, cur_len = input_ids.shape[:2]
-        if batch_size > 1 and assistant_tokenizer is not None:
-            raise ValueError("assisted generate is only supported for batch_size > 1 if assistant_tokenizer is None")
+        if batch_size > 1:
+            if assistant_tokenizer is not None:
+                raise ValueError(
+                    "assisted generate is only supported for batch_size > 1 if assistant_tokenizer is None"
+                )
+            if generation_config.prompt_lookup_num_tokens is not None:
+                raise ValueError(
+                    "assisted generate is only supported for batch_size > 1 if prompt_lookup_num_tokens is None"
+                )
+
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
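With these checks, batched assisted generation requires the target and assistant models to share a tokenizer and cannot be combined with prompt lookup decoding. A usage sketch on this branch (checkpoint names are placeholders, not from the commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoints; any main/assistant pair sharing a tokenizer behaves the same way
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")
assistant = AutoModelForCausalLM.from_pretrained("gpt2")

tokenizer.pad_token = tokenizer.eos_token  # a pad token is needed for batching
tokenizer.padding_side = "left"            # the cache alignment logic assumes left padding

prompts = ["The capital of France is", "In a shocking turn of events,"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# batch_size > 1 works as long as neither assistant_tokenizer nor prompt_lookup_num_tokens is passed
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))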

@@ -4046,7 +4054,7 @@ def repadd_batch_and_fix_cache(input_ids, past_key_values, accepted_tokens_padde
     padding_mask = cache_input_ids[:, :-1] == pad_token_id
 
     for layer in past_key_values.layers:
-        layer = compress_and_repad_cache(layer, padding_mask, pad_token_id)
+        compress_and_repad_cache(layer, padding_mask)
     # 1. Filter out current padding and repad to minimum length.
     next_input_ids_clean = [row[row != pad_token_id] for row in next_input_ids]
     next_input_ids_padded = pad_sequence(
@@ -4055,7 +4063,7 @@ def repadd_batch_and_fix_cache(input_ids, past_key_values, accepted_tokens_padde
     return next_input_ids_padded, past_key_values
 
 
-def compress_and_repad_cache(layer, padding_mask, pad_token_id):
+def compress_and_repad_cache(layer, padding_mask):
     # padding_mask: True = Pad, False = Keep
     B, H, S, D = layer.keys.shape
 
@@ -4095,5 +4103,3 @@ def compress_and_repad_cache(layer, padding_mask, pad_token_id):
     # 5. Assign back
     layer.keys = out_keys
     layer.values = out_values
-
-    return layer
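The diff shows only the signature change and the dropped return, but the surrounding code (a padding_mask where True marks slots to drop, followed by an "Assign back" step) suggests the helper compresses out padded cache positions per row and re-pads the batch to the longest surviving row. A self-contained sketch of that general compress-and-left-repad idea on one layer's keys (hypothetical shapes, not the function's actual body, which this diff does not include):

import torch

# Hypothetical layer: batch 2, 1 head, 5 cached positions, head dim 3
B, H, S, D = 2, 1, 5, 3
keys = torch.randn(B, H, S, D)
# True = padded slot to drop, False = keep (same convention as repadd_batch_and_fix_cache)
padding_mask = torch.tensor([[True, True, False, False, False],
                             [True, False, False, False, False]])

keep_lengths = (~padding_mask).sum(dim=1)  # tokens kept per row: tensor([3, 4])
S_new = int(keep_lengths.max())            # re-pad every row to the longest surviving one

out_keys = keys.new_zeros((B, H, S_new, D))
for i in range(B):
    kept = keys[i][:, ~padding_mask[i]]            # (H, keep_lengths[i], D)
    out_keys[i, :, S_new - kept.shape[1]:] = kept  # zeros on the left, content right-aligned

Removing pad_token_id from the signature and dropping the trailing return layer reflect that the helper works purely from padding_mask and mutates the layer in place, so callers no longer need to rebind it.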
