186 changes: 186 additions & 0 deletions src/transformers/cache_utils.py
@@ -80,6 +80,75 @@ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
self.values = self.values.index_select(0, beam_idx.to(self.values.device))

def align(
self,
new_seq_length: int,
copy_instructions: list[tuple[int, slice, slice]],
) -> None:
"""
Align this layer's cache based on copy instructions.

Args:
new_seq_length (`int`): The new sequence length for the aligned cache.
copy_instructions (`list[tuple[int, slice, slice]]`): List of (batch_idx, src_slice, dst_slice) tuples
specifying what to copy from the old cache to the new cache.
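
Example (a minimal sketch assuming the layer belongs to a `DynamicCache` that was populated through `update`; shapes and values are illustrative):

```python
>>> import torch
>>> from transformers import DynamicCache
>>> cache = DynamicCache()
>>> _ = cache.update(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), layer_idx=0)
>>> layer = cache.layers[0]
>>> # keep the first 3 cached positions, right-aligned in a new cache of length 5
>>> layer.align(new_seq_length=5, copy_instructions=[(0, slice(0, 3), slice(-3, None))])
>>> layer.keys.shape
torch.Size([1, 2, 5, 8])
```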
"""
if not self.is_initialized:
return

B, H, _, D = self.keys.shape
new_keys = self.keys.new_zeros((B, H, new_seq_length, D))
new_values = self.values.new_zeros((B, H, new_seq_length, D))

# Execute the pre-calculated copy instructions
for i, src_slice, dst_slice in copy_instructions:
new_keys[i, :, dst_slice] = self.keys[i, :, src_slice]
new_values[i, :, dst_slice] = self.values[i, :, src_slice]

self.keys = new_keys
self.values = new_values

def compress_and_repad_cache(self, padding_mask: torch.Tensor) -> None:
"""
Drop the positions marked in `padding_mask` from this layer's cache and re-pad the result on the left to the
longest remaining sequence.

Args:
padding_mask (`torch.Tensor`): Boolean mask of shape `[batch_size, seq_len]` where `True` marks a padded
position to drop and `False` a position to keep.
"""
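# Illustration with hypothetical values: with S = 4 and
# padding_mask = [[True, True, False, False],
#                 [True, False, False, False]],
# row 0 keeps 2 positions and row 1 keeps 3, so max_len = 3; row 0 ends up left-padded with one
# zeroed slot and the cache shrinks from sequence length 4 to 3.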
B, H, S, D = self.keys.shape

# 1. Compute lengths and dimensions
# Invert mask: True = Keep
keep_mask = ~padding_mask # [B, S]
lengths = keep_mask.sum(dim=1) # [B]
max_len = lengths.max().item()

# 2. Allocate Output (Pre-filled with padding/zeros)
# We allocate directly in the final shape [B, H, max_len, D]
out_keys = self.keys.new_zeros((B, H, max_len, D))
out_values = self.values.new_zeros((B, H, max_len, D))

# 3. Create the "Destination" mask for Left-Padding
# We want valid data to sit at the END of the sequence (Left Padding)
# Row i should have (max_len - length_i) pads, then valid data.

# shape: [max_len]
range_tensor = torch.arange(max_len, device=self.keys.device)
# shape: [B, max_len] broadcast comparison
# Example: max_len=5, len=3. We want indices 2,3,4 to be True.
# range (0,1,2,3,4) >= (5-3=2) -> F,F,T,T,T
dest_mask = range_tensor >= (max_len - lengths.unsqueeze(1))

# 4. Perform the Copy (The Fast Part)
# We transpose (B, H, S, D) -> (B, S, H, D) so the mask (B, S) aligns
# This extracts ONLY the valid tokens into a flat buffer [Total_Valid, H, D]
valid_keys = self.keys.transpose(1, 2)[keep_mask]
valid_values = self.values.transpose(1, 2)[keep_mask]

# Assign into output using the destination mask
# We transpose output to (B, max_len, H, D) to align with dest_mask (B, max_len)
out_keys.transpose(1, 2)[dest_mask] = valid_keys
out_values.transpose(1, 2)[dest_mask] = valid_values

# 5. Assign back
self.keys = out_keys
self.values = out_values


class DynamicLayer(CacheLayerMixin):
"""
@@ -891,6 +960,94 @@ def __len__(self):
# forward through all the layers
return len(self.layers)

def align(
self,
new_ids: torch.LongTensor,
ids_in_cache: torch.LongTensor,
pad_token_id: int,
return_new_ids_in_cache: bool = False,
):
"""
Align the cache when input sequences change (e.g., when batching different sequences together).

Args:
new_ids (`torch.LongTensor`): The new input IDs after batching changes.
ids_in_cache (`torch.LongTensor`): The input IDs that were used to build the current cache.
pad_token_id (`int`): The padding token ID.
return_new_ids_in_cache (`bool`, *optional*, defaults to `False`): Whether to return the aligned input IDs.

Returns:
`None` if `return_new_ids_in_cache=False`, otherwise the aligned input IDs tensor.
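
Example (a minimal sketch assuming a `DynamicCache` populated through `update`; token IDs and shapes are illustrative):

```python
>>> import torch
>>> from transformers import DynamicCache
>>> cache = DynamicCache()
>>> _ = cache.update(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), layer_idx=0)
>>> old_ids = torch.tensor([[0, 5, 6, 7]])      # left-padded prompt the cache was built on (pad_token_id=0)
>>> new_ids = torch.tensor([[5, 6, 7, 9, 10]])  # same prefix without padding, plus new tokens
>>> cache.align(new_ids, old_ids, pad_token_id=0, return_new_ids_in_cache=True)
tensor([[0, 5, 6, 7]])
>>> cache.layers[0].keys.shape                  # resized to new_ids length - 1
torch.Size([1, 2, 4, 8])
```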
"""
# 1. Setup metadata (Shape: [Batch, Heads, Sequence_Length, Dimension])
# We access the first layer just to get shapes and device
if len(self.layers) == 0 or not self.layers[0].is_initialized:
raise ValueError("Cache is not initialized")

ref_layer = self.layers[0]
B, H, S_old, D = ref_layer.keys.shape
S_new = new_ids.shape[1] - 1  # the aligned cache covers every token of `new_ids` except the last one, which is processed by the next forward pass

# 2. Pre-calculate "What to copy" for the whole batch ONCE.

# Find start indices (Vectorized)
# Note: sum() assumes left-padding only.
old_start_indices = (ids_in_cache == pad_token_id).sum(dim=1)
new_start_indices = (new_ids == pad_token_id).sum(dim=1)

# We will store the copy instructions here to apply to all layers later
# Format: List of tuples (batch_idx, source_slice, dest_slice)
copy_instructions = []

# We still loop over batch (B), but only once, not B * Layers
for i in range(B):
# Identify the content without padding
# Standard Python slicing here is just index math, so it is very cheap
o_start = old_start_indices[i].item()
n_start = new_start_indices[i].item()

# Get the actual token sequences (views, not copies)
# We perform the comparison on the ID tensors (int64), which is cheap
trimmed_old = ids_in_cache[i, o_start:]
trimmed_new = new_ids[i, n_start:]

min_len = min(len(trimmed_old), len(trimmed_new))

# Compare only up to min_len
# Using .ne() (not equal) and locating the first True is faster than comparing token by token in Python
if min_len == 0:
copy_len = 0
else:
# Find mismatch: (a != b)
mismatch = trimmed_old[:min_len].ne(trimmed_new[:min_len])
if not mismatch.any():
copy_len = min_len
else:
# argmax on boolean gives index of first True
copy_len = mismatch.int().argmax().item()
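# Illustrative values: trimmed_old[:min_len] = [5, 6, 7] and trimmed_new[:min_len] = [5, 6, 9]
# -> mismatch = [False, False, True] -> copy_len = 2, so only the shared prefix is reused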

if copy_len > 0:
# Build the slice objects once here so they are not recreated for every layer
src_slice = slice(o_start, o_start + copy_len)
# Right-align the copied prefix in the new cache (left padding)
dst_slice = slice(-copy_len, None)
copy_instructions.append((i, src_slice, dst_slice))

# 3. Apply changes to all layers using per-layer align method
for layer in self.layers:
layer.align(S_new, copy_instructions)

if return_new_ids_in_cache:
new_input_ids_in_cache = ids_in_cache.new_full((B, S_new), pad_token_id)  # pad with pad_token_id so the result can be reused as `ids_in_cache`
# Execute the copy instructions for input IDs
for i, src_slice, dst_slice in copy_instructions:
new_input_ids_in_cache[i, dst_slice] = ids_in_cache[i, src_slice]
return new_input_ids_in_cache

def compress_and_repad_cache(self, padding_mask):
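"""
Drop the positions marked in `padding_mask` from every layer's cache and re-pad each layer on the left to the
longest remaining sequence.

Args:
padding_mask (`torch.Tensor`): Boolean mask of shape `[batch_size, seq_len]` where `True` marks a padded
position to drop and `False` a position to keep.

Example (a minimal sketch assuming a `DynamicCache` populated through `update`; shapes are illustrative):

```python
>>> import torch
>>> from transformers import DynamicCache
>>> cache = DynamicCache()
>>> _ = cache.update(torch.randn(2, 2, 4, 8), torch.randn(2, 2, 4, 8), layer_idx=0)
>>> # drop two positions from row 0 and one from row 1; both rows are re-padded on the left to length 3
>>> padding_mask = torch.tensor([[True, True, False, False],
...                              [True, False, False, False]])
>>> cache.compress_and_repad_cache(padding_mask)
>>> cache.layers[0].keys.shape
torch.Size([2, 2, 3, 8])
```
"""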
for layer in self.layers:
layer.compress_and_repad_cache(padding_mask)


class DynamicCache(Cache):
"""
@@ -1277,6 +1434,35 @@ def batch_select_indices(self, indices: torch.Tensor):
self.self_attention_cache.batch_select_indices(indices)
self.cross_attention_cache.batch_select_indices(indices)

def align(
self,
new_ids: torch.LongTensor,
ids_in_cache: torch.LongTensor,
pad_token_id: int,
return_new_ids_in_cache: bool = False,
):
"""
Align the cache when input sequences change (e.g., when batching different sequences together).
Only the self-attention cache is realigned; the cross-attention cache depends on the encoder sequence and is left unchanged.

Args:
new_ids (`torch.LongTensor`): The new input IDs after batching changes.
ids_in_cache (`torch.LongTensor`): The input IDs that were used to build the current cache.
pad_token_id (`int`): The padding token ID.
return_new_ids_in_cache (`bool`, *optional*, defaults to `False`): Whether to return the aligned input IDs.

Returns:
`None` if `return_new_ids_in_cache=False`, otherwise the aligned input IDs tensor.
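
Example (a minimal sketch assuming the usual `EncoderDecoderCache(self_attention_cache, cross_attention_cache)` constructor; shapes and IDs are illustrative):

```python
>>> import torch
>>> from transformers import DynamicCache, EncoderDecoderCache
>>> self_attn, cross_attn = DynamicCache(), DynamicCache()
>>> _ = self_attn.update(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), layer_idx=0)
>>> cache = EncoderDecoderCache(self_attn, cross_attn)
>>> cache.align(torch.tensor([[5, 6, 7, 9, 10]]), torch.tensor([[0, 5, 6, 7]]), pad_token_id=0)
>>> cache.self_attention_cache.layers[0].keys.shape
torch.Size([1, 2, 4, 8])
```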
"""
# `Cache.align` only returns the aligned IDs when requested, so its result can be forwarded directly
return self.self_attention_cache.align(new_ids, ids_in_cache, pad_token_id, return_new_ids_in_cache)

def compress_and_repad_cache(self, padding_mask):
"""
Compress and re-pad the self-attention cache. The cross-attention cache length is tied to the encoder
sequence, so it is left unchanged.
"""
self.self_attention_cache.compress_and_repad_cache(padding_mask)

def get_max_cache_shape(self) -> int:
"""Returns the maximum sequence length (i.e. max capacity) of the cache object"""
return self.self_attention_cache.get_max_cache_shape()