@@ -80,6 +80,75 @@ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
         self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
         self.values = self.values.index_select(0, beam_idx.to(self.values.device))

+    def align(
+        self,
+        new_seq_length: int,
+        copy_instructions: list[tuple[int, slice, slice]],
+    ) -> None:
+        """
+        Align this layer's cache based on copy instructions.
+
+        Args:
+            new_seq_length (`int`): The new sequence length for the aligned cache.
+            copy_instructions (`list[tuple[int, slice, slice]]`): List of (batch_idx, src_slice, dst_slice) tuples
+                specifying what to copy from the old cache to the new cache.
+        """
+        if not self.is_initialized:
+            return
+
+        B, H, _, D = self.keys.shape
+        new_keys = self.keys.new_zeros((B, H, new_seq_length, D))
+        new_values = self.values.new_zeros((B, H, new_seq_length, D))
+
+        # Execute the pre-calculated copy instructions
+        for i, src_slice, dst_slice in copy_instructions:
+            new_keys[i, :, dst_slice] = self.keys[i, :, src_slice]
+            new_values[i, :, dst_slice] = self.values[i, :, src_slice]
+
+        self.keys = new_keys
+        self.values = new_values
+
+    def compress_and_repad_cache(self, padding_mask):
+        """
+        Drop padded positions from this layer's cache and re-pack the remaining tokens as a
+        left-padded sequence of length `max_len` (the longest un-padded row in the batch).
+
+        Args:
+            padding_mask (`torch.BoolTensor` of shape `[batch_size, seq_len]`): True = pad, False = keep.
+        """
+        B, H, S, D = self.keys.shape
+
+        # 1. Compute lengths and dimensions
+        # Invert mask: True = keep
+        keep_mask = ~padding_mask  # [B, S]
+        lengths = keep_mask.sum(dim=1)  # [B]
+        max_len = lengths.max().item()
+
+        # 2. Allocate output (pre-filled with padding/zeros)
+        # We allocate directly in the final shape [B, H, max_len, D]
+        out_keys = self.keys.new_zeros((B, H, max_len, D))
+        out_values = self.values.new_zeros((B, H, max_len, D))
+
+        # 3. Create the "destination" mask for left-padding
+        # We want valid data to sit at the END of the sequence (left padding):
+        # row i should have (max_len - length_i) pads, then valid data.
+
+        # shape: [max_len]
+        range_tensor = torch.arange(max_len, device=self.keys.device)
+        # shape: [B, max_len] broadcast comparison
+        # Example: max_len=5, len=3. We want indices 2, 3, 4 to be True.
+        # range (0,1,2,3,4) >= (5-3=2) -> F,F,T,T,T
+        dest_mask = range_tensor >= (max_len - lengths.unsqueeze(1))
+
+        # 4. Perform the copy (the fast part)
+        # We transpose (B, H, S, D) -> (B, S, H, D) so the mask (B, S) aligns.
+        # This extracts ONLY the valid tokens into a flat buffer [total_valid, H, D]
+        valid_keys = self.keys.transpose(1, 2)[keep_mask]
+        valid_values = self.values.transpose(1, 2)[keep_mask]
+
+        # Assign into the output using the destination mask.
+        # We transpose the output to (B, max_len, H, D) to align with dest_mask (B, max_len)
+        out_keys.transpose(1, 2)[dest_mask] = valid_keys
+        out_values.transpose(1, 2)[dest_mask] = valid_values
+
+        # 5. Assign back
+        self.keys = out_keys
+        self.values = out_values
+

 class DynamicLayer(CacheLayerMixin):
     """
@@ -891,6 +960,94 @@ def __len__(self):
         # forward through all the layers
         return len(self.layers)

+    def align(
+        self,
+        new_ids: torch.LongTensor,
+        ids_in_cache: torch.LongTensor,
+        pad_token_id: int,
+        return_new_ids_in_cache: bool = False,
+    ):
+        """
+        Align the cache when input sequences change (e.g., when batching different sequences together).
+
+        Args:
+            new_ids (`torch.LongTensor`): The new input IDs after batching changes.
+            ids_in_cache (`torch.LongTensor`): The input IDs that were used to build the current cache.
+            pad_token_id (`int`): The padding token ID.
+            return_new_ids_in_cache (`bool`, *optional*, defaults to `False`): Whether to return the aligned input IDs.
+
+        Returns:
+            `None` if `return_new_ids_in_cache=False`, otherwise the aligned input IDs tensor.
+        """
+        # 1. Setup metadata (shape: [Batch, Heads, Sequence_Length, Dimension]).
+        # We access the first layer just to get shapes and device.
+        if len(self.layers) == 0 or not self.layers[0].is_initialized:
+            raise ValueError("Cache is not initialized")
+
+        ref_layer = self.layers[0]
+        B, H, S_old, D = ref_layer.keys.shape
+        S_new = new_ids.shape[1] - 1  # New cache length: all but the last token of new_ids
+
+        # 2. Pre-calculate "what to copy" for the whole batch ONCE.
+
+        # Find start indices (vectorized).
+        # Note: sum() assumes left-padding only.
+        old_start_indices = (ids_in_cache == pad_token_id).sum(dim=1)
+        new_start_indices = (new_ids == pad_token_id).sum(dim=1)
+
+        # We will store the copy instructions here to apply to all layers later.
+        # Format: list of tuples (batch_idx, source_slice, dest_slice)
+        copy_instructions = []
+
+        # We still loop over the batch (B), but only once, not B * num_layers times
+        for i in range(B):
+            # Identify the content without padding.
+            # We use standard Python slicing here as it's just index math, very fast.
+            o_start = old_start_indices[i].item()
+            n_start = new_start_indices[i].item()
+
+            # Get the actual token sequences (views, not copies).
+            # We perform the comparison on the ID tensors (int64), which is cheap.
+            trimmed_old = ids_in_cache[i, o_start:]
+            trimmed_new = new_ids[i, n_start:]
+
+            min_len = min(len(trimmed_old), len(trimmed_new))
+
+            # Compare only up to min_len.
+            # Using .ne() (not equal) and finding the first True is faster than checking positions one by one in Python.
+            if min_len == 0:
+                copy_len = 0
+            else:
+                # Find mismatches: (a != b)
+                mismatch = trimmed_old[:min_len].ne(trimmed_new[:min_len])
+                if not mismatch.any():
+                    copy_len = min_len
+                else:
+                    # argmax on the int-cast boolean mask gives the index of the first True
+                    copy_len = mismatch.int().argmax().item()
+
+            if copy_len > 0:
+                # Define the slice objects once so we don't recreate them for every layer
+                src_slice = slice(o_start, o_start + copy_len)
+                # Right-align the shared prefix in the new cache (left padding)
+                dst_slice = slice(-copy_len, None)
+                copy_instructions.append((i, src_slice, dst_slice))
+
+        # 3. Apply the changes to all layers using the per-layer align method
+        for layer in self.layers:
+            layer.align(S_new, copy_instructions)
+
+        if return_new_ids_in_cache:
+            new_input_ids_in_cache = ids_in_cache.new_zeros((B, S_new))
+            # Execute the copy instructions for the input IDs as well
+            for i, src_slice, dst_slice in copy_instructions:
+                new_input_ids_in_cache[i, dst_slice] = ids_in_cache[i, src_slice]
+            return new_input_ids_in_cache
+
+    def compress_and_repad_cache(self, padding_mask):
+        """Drop padded positions and re-pack every layer as a left-padded sequence. `padding_mask`: True = pad."""
+        for layer in self.layers:
+            layer.compress_and_repad_cache(padding_mask)
+

 class DynamicCache(Cache):
     """
@@ -1277,6 +1434,35 @@ def batch_select_indices(self, indices: torch.Tensor):
         self.self_attention_cache.batch_select_indices(indices)
         self.cross_attention_cache.batch_select_indices(indices)

+    def align(
+        self,
+        new_ids: torch.LongTensor,
+        ids_in_cache: torch.LongTensor,
+        pad_token_id: int,
+        return_new_ids_in_cache: bool = False,
+    ):
+        """
+        Align the cache when input sequences change (e.g., when batching different sequences together).
+        Only the self-attention cache is realigned; the cross-attention cache is left untouched.
+
+        Args:
+            new_ids (`torch.LongTensor`): The new input IDs after batching changes.
+            ids_in_cache (`torch.LongTensor`): The input IDs that were used to build the current cache.
+            pad_token_id (`int`): The padding token ID.
+            return_new_ids_in_cache (`bool`, *optional*, defaults to `False`): Whether to return the aligned input IDs.
+
+        Returns:
+            `None` if `return_new_ids_in_cache=False`, otherwise the aligned input IDs tensor.
+        """
+        # Cache.align already returns `None` when `return_new_ids_in_cache` is False,
+        # so we simply forward its result.
+        return self.self_attention_cache.align(new_ids, ids_in_cache, pad_token_id, return_new_ids_in_cache)
+
+    def compress_and_repad_cache(self, padding_mask):
+        """Drop padded positions from the self-attention cache and re-pack it left-padded. `padding_mask`: True = pad."""
+        self.self_attention_cache.compress_and_repad_cache(padding_mask)
+
     def get_max_cache_shape(self) -> int:
         """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
         return self.self_attention_cache.get_max_cache_shape()
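
Finally, a hypothetical end-to-end usage sketch (not part of the diff) showing how the two new methods might be chained. It assumes `cache` is an already-populated cache exposing the `align` and `compress_and_repad_cache` methods added above, and that `ids_in_cache` holds the left-padded ids the cache was built from; treating zero-filled positions of the returned ids as padding is an additional assumption that only holds if 0 cannot appear there as a real token id.

def realign_for_new_batch(cache, ids_in_cache, new_ids, pad_token_id):
    # Re-layout the cached keys/values so they match the new batch, keeping
    # only the shared left-padded prefixes between old and new sequences.
    new_ids_in_cache = cache.align(
        new_ids, ids_in_cache, pad_token_id, return_new_ids_in_cache=True
    )
    # Positions the align step did not fill stay zero (new_zeros); treating
    # them as padding assumes 0 never occurs as a real token id here.
    padding_mask = new_ids_in_cache == 0
    # Drop those padded positions and re-pack every layer left-padded.
    cache.compress_and_repad_cache(padding_mask)
    return new_ids_in_cache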