@@ -58,7 +58,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
             f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
         )
 
-    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, assistant_used: bool = True):
+    def update_candidate_strategy(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: torch.LongTensor | int, assistant_used: bool = True
+    ):
         """
         Updates the candidate generation strategy based on the outcomes.
 
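For orientation, here is a minimal sketch of the two-method contract this base class defines, using the int-or-tensor `num_matches` convention documented below. The class name and the repeat-last-token strategy are hypothetical, purely illustrative:

```python
import torch


class RepeatLastTokenCandidateGenerator:
    """Toy generator: proposes `n` copies of the last token (hypothetical)."""

    def __init__(self, n: int = 5):
        self.n = n

    def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, None]:
        # Deliberately weak candidates: repeat the final token n times.
        tail = input_ids[:, -1:].expand(-1, self.n)
        return torch.cat([input_ids, tail], dim=-1), None

    def update_candidate_strategy(self, input_ids, scores, num_matches, assistant_used=True):
        # `num_matches` may be an int (batch_size=1) or a per-item tensor.
        if isinstance(num_matches, int):
            num_matches = torch.tensor([num_matches], device=input_ids.device)
        # Propose one more token next round if, on average, everything was accepted.
        if num_matches.float().mean().item() >= self.n:
            self.n += 1
```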
@@ -199,7 +201,9 @@ def __init__(
         self.matches = []
         self.clean_probs = []
 
-    def get_candidates(self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
+    def get_candidates(
+        self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor | None = None
+    ) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
         """
         Fetches the candidates to be tried for the current input.
 
@@ -224,7 +228,9 @@ def get_candidates(self, input_ids: torch.LongTensor, assistant_ids_in_cache: to
         candidate_ids, candidate_logits = self._generate_candidates(generation_args)
         return candidate_ids, candidate_logits
 
-    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, assistant_used: bool = True):
+    def update_candidate_strategy(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: torch.LongTensor | int, assistant_used: bool = True
+    ):
         """
         Updates the candidate generation strategy based on the outcomes.
 
@@ -239,21 +245,20 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
                 If `int`, assumes `batch_size=1` for backward compatibility.
             assistant_used (`bool`):
                 Whether the assistant was used to generate the candidates. Assistant was not used if max_new_tokens is 0.
-        """
+        """
         # Handle backward compatibility: convert int to tensor
         if isinstance(num_matches, int):
             num_matches = torch.tensor([num_matches], device=input_ids.device)
-
+
         batch_size = input_ids.shape[0]
-
+
         # Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
         # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
         # cost of forecasting incorrect assistant tokens.
         if self.assistant_model.generation_config.num_assistant_tokens_schedule in {
             "heuristic",
             "heuristic_transient",
         }:
-
             # For batch processing, we can use different strategies:
             # Option 1: Use average matches across batch
             avg_matches = num_matches.float().mean().item()
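The next hunk applies the schedule's threshold rule to this batch average. As a self-contained sketch, assuming the usual single-sequence heuristic (grow by 2 on a full match, otherwise shrink by 1, never below 1); the function name is illustrative:

```python
import torch


def adjust_num_assistant_tokens(num_assistant_tokens: float, num_matches: torch.Tensor) -> float:
    # Option 1 from the comment above: average the per-item match counts.
    avg_matches = num_matches.float().mean().item()
    if avg_matches >= int(num_assistant_tokens):
        return num_assistant_tokens + 2.0  # all draft tokens accepted: draft more next round
    return max(1.0, num_assistant_tokens - 1.0)  # a rejection occurred: draft fewer
```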
@@ -268,7 +273,7 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
                 self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
 
         # The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes.
-        # The costs are estimated based on the ROC curve, which considers the probability of the draft token and its match with the target.
+        # The costs are estimated based on the ROC curve, which considers the probability of the draft token and its match with the target.
         # A cost of 25% is assigned to false positives and 75% to false negatives.
         # This adaptation is not compatible with UAG, as it relies on the number of matched tokens based on the draft vocabulary, which is unavailable in UAG.
         if (
@@ -287,13 +292,13 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
                 # this means we reject a token.
                 item_matches.append(0)
                 # take only the relevant probabilities: those of the accepted tokens and the first rejected token.
-                self.clean_probs.extend([self.probs[len(self.matches)][:len(item_matches)]])
+                self.clean_probs.extend([self.probs[len(self.matches)][: len(item_matches)]])
                 self.matches.extend([item_matches])
-
+
             assert len(self.matches) == len(self.clean_probs), "matches and probs must have the same length"
             clean_matches = np.concatenate(self.matches)
             clean_probs = np.concatenate(self.clean_probs)
-
+
             # calculate ROC curve and update threshold if we have enough samples
             if (
                 len(clean_probs) > 5 and {0, 1}.issubset(clean_matches)
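For reference, a standalone sketch of the threshold selection this condition guards, assuming the 25%/75% costs named in the comment above map directly onto the ROC curve's false-positive and false-negative rates; the helper name is illustrative:

```python
import numpy as np
from sklearn.metrics import roc_curve


def pick_confidence_threshold(clean_matches: np.ndarray, clean_probs: np.ndarray) -> float:
    # clean_matches: 1 where the target model accepted the draft token, else 0.
    # clean_probs: the draft model's probability for each of those tokens.
    fpr, tpr, thresholds = roc_curve(clean_matches, clean_probs)
    cost = 0.25 * fpr + 0.75 * (1.0 - tpr)  # false negatives cost three times as much
    return float(thresholds[np.argmin(cost)])
```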
@@ -318,20 +323,29 @@ def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> tuple[int, int]:
         return min_new_tokens, max_new_tokens
 
     def _update_past_and_masks(
-        self, input_ids: torch.LongTensor, remove_from_pkv: int = 0, num_added_tokens: int = 1, assistant_ids_in_cache: torch.LongTensor = None
+        self,
+        input_ids: torch.LongTensor,
+        remove_from_pkv: int = 0,
+        num_added_tokens: int = 1,
+        assistant_ids_in_cache: torch.LongTensor | None = None,
     ) -> bool:
         """Update past key values and attention masks for subsequent generation rounds."""
         has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
         if has_past_key_values:
             new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
             current_cache = self.assistant_kwargs["past_key_values"]
-            if (batch_size := input_ids.shape[0]) > 1:
-                self.assistant_kwargs["past_key_values"] = align_cache(current_cache, input_ids, assistant_ids_in_cache,
-                                                                       self.generation_config.pad_token_id)
+            if input_ids.shape[0] > 1:
+                self.assistant_kwargs["past_key_values"] = align_cache(
+                    current_cache, input_ids, assistant_ids_in_cache, self.generation_config.pad_token_id
+                )
             else:
                 self.assistant_kwargs["past_key_values"].crop(new_cache_size - num_added_tokens)
             self.assistant_kwargs = _prepare_attention_mask(
-                self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder, input_ids, self.generation_config.pad_token_id
+                self.assistant_kwargs,
+                input_ids.shape[-1],
+                self.assistant_model.config.is_encoder_decoder,
+                input_ids,
+                self.generation_config.pad_token_id,
             )
             self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])
 
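The batch > 1 branch exists because, with left-padded batches, each row's valid tokens begin at a different offset, so a single `crop` length cannot be correct for every row. A toy illustration of the per-row offsets `align_cache` computes below (`pad_token_id = 0` is an arbitrary choice for the demo):

```python
import torch

pad_token_id = 0
ids = torch.tensor([
    [0, 0, 5, 6, 7],  # two pad tokens, three real tokens
    [0, 9, 9, 9, 9],  # one pad token, four real tokens
])
start = (ids == pad_token_id).sum(dim=1)  # left-padding only, as align_cache assumes
print(start.tolist())  # [2, 1]: each row's content starts at a different index
```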
@@ -355,15 +369,23 @@ def _generate_candidates(self, generation_args: dict) -> tuple[torch.LongTensor,
         """Generate candidate sequences using the assistant model."""
         assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
         self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
-        candidate_logits = torch.stack(assistant_output.scores, dim=1)  # shape: (batch_size, candidate_length, vocab_size)
+        candidate_logits = torch.stack(
+            assistant_output.scores, dim=1
+        )  # shape: (batch_size, candidate_length, vocab_size)
         if (
             is_sklearn_available()
             and self.assistant_model.generation_config.assistant_confidence_threshold
             and type(self) is AssistedCandidateGenerator
         ):
-            scores_softmax = torch.softmax(candidate_logits, dim=-1)  # shape: (batch_size, candidate_length, vocab_size)
-            ids = assistant_output.sequences[:, -len(assistant_output.scores):]  # shape: (batch_size, candidate_length)
-            p = torch.gather(scores_softmax, dim=-1, index=ids.unsqueeze(-1)).squeeze(-1)  # shape: (batch_size, candidate_length)
+            scores_softmax = torch.softmax(
+                candidate_logits, dim=-1
+            )  # shape: (batch_size, candidate_length, vocab_size)
+            ids = assistant_output.sequences[
+                :, -len(assistant_output.scores) :
+            ]  # shape: (batch_size, candidate_length)
+            p = torch.gather(scores_softmax, dim=-1, index=ids.unsqueeze(-1)).squeeze(
+                -1
+            )  # shape: (batch_size, candidate_length)
             self.probs.extend(p.tolist())
         candidate_ids = assistant_output.sequences
         return candidate_ids, candidate_logits
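A standalone sketch of the probability bookkeeping above: stack the per-step scores, softmax over the vocabulary, and gather each generated token's own probability (shapes chosen arbitrarily for the demo):

```python
import torch

batch, steps, vocab = 2, 3, 11
scores = [torch.randn(batch, vocab) for _ in range(steps)]  # one score tensor per generated token
sequences = torch.randint(vocab, (batch, 8))  # prompt + generated ids

logits = torch.stack(scores, dim=1)  # (batch, steps, vocab)
probs = torch.softmax(logits, dim=-1)
gen_ids = sequences[:, -steps:]  # the generated suffix
p = torch.gather(probs, -1, gen_ids.unsqueeze(-1)).squeeze(-1)  # (batch, steps)
print(p.shape)  # torch.Size([2, 3])
```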
@@ -530,7 +552,9 @@ def convert_source_tokens_to_target_tokens(
         dest_ids = destination_tokenizer(text, add_special_tokens=True, return_tensors="pt")["input_ids"]
         return dest_ids.to(input_ids.device)
 
-    def get_candidates(self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
+    def get_candidates(
+        self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor | None = None
+    ) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
         """
         Fetches the candidates to be tried for the current input.
 
@@ -555,7 +579,9 @@ def get_candidates(self, input_ids: torch.LongTensor, assistant_ids_in_cache: to
 
         min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - assistant_input_ids.shape[-1]), 0)
 
-        self._update_past_and_masks(assistant_input_ids, remove_from_pkv, assistant_ids_in_cache=assistant_ids_in_cache)
+        self._update_past_and_masks(
+            assistant_input_ids, remove_from_pkv, assistant_ids_in_cache=assistant_ids_in_cache
+        )
         generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
         self.assistant_kwargs.pop("attention_mask", None)
 
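For context, `assistant_input_ids` above comes from the `convert_source_tokens_to_target_tokens` hunk earlier, which round-trips through text: decode with the source tokenizer, re-encode with the destination one. A minimal sketch, assuming batch_size=1 and Hugging Face-style tokenizers:

```python
import torch


def convert_tokens(input_ids: torch.LongTensor, source_tokenizer, destination_tokenizer) -> torch.LongTensor:
    # Decode to text with the draft model's tokenizer...
    text = source_tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
    # ...then re-encode with the target model's tokenizer.
    dest_ids = destination_tokenizer(text, add_special_tokens=True, return_tensors="pt")["input_ids"]
    return dest_ids.to(input_ids.device)
```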
@@ -955,7 +981,9 @@ def __init__(
         self._target_seq_len_with_candidates: int = 0
         self._prev_assistant_ids: torch.LongTensor | None = None
 
-    def get_candidates(self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
+    def get_candidates(
+        self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor | None = None
+    ) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
         """
         Simplified version of get_candidates that uses the translator cache for token conversion.
         """
@@ -966,7 +994,9 @@ def get_candidates(self, input_ids: torch.LongTensor, assistant_ids_in_cache: to
         if max_new_tokens == 0:
             return input_ids, None
 
-        self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens, assistant_ids_in_cache=assistant_ids_in_cache)
+        self._update_past_and_masks(
+            assistant_input_ids, num_added_tokens=num_added_tokens, assistant_ids_in_cache=assistant_ids_in_cache
+        )
         generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
 
         # Ensure scores are returned
@@ -987,7 +1017,12 @@ def get_candidates(self, input_ids: torch.LongTensor, assistant_ids_in_cache: to
 
         return target_candidate_ids, target_candidate_logits
 
-    def _update_past_and_masks(self, assistant_input_ids: torch.LongTensor, num_added_tokens: int = 1, assistant_ids_in_cache: torch.LongTensor = None) -> bool:
+    def _update_past_and_masks(
+        self,
+        assistant_input_ids: torch.LongTensor,
+        num_added_tokens: int = 1,
+        assistant_ids_in_cache: torch.LongTensor | None = None,
+    ) -> bool:
         if self._prev_assistant_ids is None:
             # Prepare attention mask for the first generation.
             # For subsequent generations, the attention mask is updated in super()._update_past_and_masks().
@@ -1175,7 +1210,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
         # assisted_generation expects logits as well, but we don't have those here, so returning None
         return candidate_input_ids, None
 
-    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, assistant_used: bool = True):
+    def update_candidate_strategy(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: torch.LongTensor | int, assistant_used: bool = True
+    ):
         """
         Updates the candidate generation strategy based on the outcomes.
 
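Since this class drafts candidates without an assistant model, a compact sketch of the underlying prompt-lookup idea may help: reuse the continuation of an earlier occurrence of the trailing n-gram. The naive scan below is illustrative only (the real implementation is vectorized) and assumes batch_size=1:

```python
import torch


def prompt_lookup(input_ids: torch.LongTensor, ngram_size: int = 3, num_candidates: int = 10) -> torch.LongTensor:
    ids = input_ids[0]  # assume batch_size=1 for the sketch
    ngram = ids[-ngram_size:]
    # Scan earlier positions, latest first, for a matching n-gram.
    for start in range(ids.shape[0] - ngram_size - 1, -1, -1):
        if torch.equal(ids[start : start + ngram_size], ngram):
            continuation = ids[start + ngram_size : start + ngram_size + num_candidates]
            if continuation.numel() > 0:
                return torch.cat([input_ids, continuation.unsqueeze(0)], dim=-1)
    return input_ids  # no match found: propose nothing new
```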
@@ -1250,7 +1287,13 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
         return candidate_ids, candidate_logits
 
 
-def _prepare_attention_mask(model_kwargs: dict[str, Any], new_length: int, is_encoder_decoder: bool, input_ids: torch.LongTensor = None, pad_token_id: int = None) -> dict[str, Any]:
+def _prepare_attention_mask(
+    model_kwargs: dict[str, Any],
+    new_length: int,
+    is_encoder_decoder: bool,
+    input_ids: torch.LongTensor | None = None,
+    pad_token_id: int | None = None,
+) -> dict[str, Any]:
     """Expands or crops the model's mask for decoding purposes, to the defined length"""
 
     mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
@@ -1261,7 +1304,7 @@ def _prepare_attention_mask(model_kwargs: dict[str, Any], new_length: int, is_en
     mask_length_diff = new_length - mask.shape[1]
     if input_ids is not None and pad_token_id is not None:
         model_kwargs[mask_key] = (input_ids != pad_token_id).to(mask.dtype)
-    elif mask_length_diff < 0: # not sure when we get into this case
+    elif mask_length_diff < 0:  # not sure when we get into this case
         model_kwargs[mask_key] = mask[:, :mask_length_diff]
     elif mask_length_diff > 0:
         model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
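In toy form, the crop/extend branches above behave as follows (the first branch simply rebuilds the mask as `input_ids != pad_token_id`, so it is skipped here):

```python
import torch

mask = torch.ones(1, 4, dtype=torch.long)
new_length = 6
diff = new_length - mask.shape[1]
if diff > 0:  # extend with ones on the right
    mask = torch.cat([mask, mask.new_ones((mask.shape[0], diff))], dim=-1)
elif diff < 0:  # crop from the right
    mask = mask[:, :diff]
print(mask.shape)  # torch.Size([1, 6])
```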
@@ -1307,38 +1350,38 @@ def _prepare_token_type_ids(model_kwargs: dict[str, Any], new_length: int) -> di
 def align_cache(cache, new_ids, assistant_ids_in_cache, pad_token_id, apply_same_transform_on_old_ids: bool = False):
     # 1. Setup metadata (Shape: [Batch, Heads, Sequence_Length, Dimension])
     # We access the first layer just to get shapes and device
-
+
     ref_layer = cache.layers[0]
     old_ids = assistant_ids_in_cache
     B, H, S_old, D = ref_layer.keys.shape
-    S_new = new_ids.shape[1] - 1 # Preserving the original sizing logic
-
+    S_new = new_ids.shape[1] - 1  # Preserving the original sizing logic
+
     # 2. Pre-calculate what to copy for the whole batch once.
     # This removes the logic calculation from the inner loop (32x-80x speedup on logic)
-
+
     # Find start indices (vectorized)
-    # Note: sum() assumes left-padding only.
+    # Note: sum() assumes left-padding only.
     old_start_indices = (old_ids == pad_token_id).sum(dim=1)
     new_start_indices = (new_ids == pad_token_id).sum(dim=1)
-
+
     # We will store the copy instructions here and apply them to all layers later
     # Format: list of tuples (batch_idx, source_slice, dest_slice)
     copy_instructions = []
-
+
     # We still loop over the batch (B), but only once, not B * num_layers times
     for i in range(B):
         # Identify the content without padding
         # We use standard Python slicing here as it's just index math, very fast
         o_start = old_start_indices[i].item()
         n_start = new_start_indices[i].item()
-
+
         # Get the actual token sequences (views, not copies)
         # We perform the comparison on the ID tensors (int64), which is cheap
         trimmed_old = old_ids[i, o_start:]
         trimmed_new = new_ids[i, n_start:]
-
+
         min_len = min(len(trimmed_old), len(trimmed_new))
-
+
         # Compare only up to min_len
         # Using .ne() (not equal) and taking the first hit is faster than elementwise Python checks
         if min_len == 0:
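The loop's next step (in the following hunk) reduces to finding the length of the common prefix of the two trimmed rows. A self-contained sketch of that computation, assuming the `.ne()`-based search the comment describes; the helper name is illustrative:

```python
import torch


def common_prefix_len(trimmed_old: torch.Tensor, trimmed_new: torch.Tensor) -> int:
    min_len = min(len(trimmed_old), len(trimmed_new))
    if min_len == 0:
        return 0
    mismatches = trimmed_old[:min_len].ne(trimmed_new[:min_len]).nonzero()
    return int(mismatches[0]) if mismatches.numel() else min_len
```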
@@ -1356,7 +1399,7 @@ def align_cache(cache, new_ids, assistant_ids_in_cache, pad_token_id, apply_same
         # Define the slice objects once so we don't recreate them for every layer
         src_slice = slice(o_start, o_start + copy_len)
         # Align to the right (-copy_len:)
-        dst_slice = slice(-copy_len, None)
+        dst_slice = slice(-copy_len, None)
         copy_instructions.append((i, src_slice, dst_slice))
 
 
@@ -1373,7 +1416,7 @@ def align_cache(cache, new_ids, assistant_ids_in_cache, pad_token_id, apply_same
             # Copy keys/values
             new_keys[i, :, dst_slice] = layer.keys[i, :, src_slice]
             new_values[i, :, dst_slice] = layer.values[i, :, src_slice]
-
+
             if apply_same_transform_on_old_ids and new_input_ids_in_cache is not None:
                 new_input_ids_in_cache[i, dst_slice] = assistant_ids_in_cache[i, src_slice]
         # Update the layer
@@ -1382,4 +1425,4 @@ def align_cache(cache, new_ids, assistant_ids_in_cache, pad_token_id, apply_same
 
     if apply_same_transform_on_old_ids:
         return cache, new_input_ids_in_cache
-    return cache
+    return cache
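A hypothetical toy invocation, assuming only what the shape comment at the top of `align_cache` states (layers exposing `.keys`/`.values` as `(B, H, S, D)` tensors). Everything below is fabricated scaffolding for illustration, not a real `Cache` object, and may need adjusting to the parts of the function outside this diff:

```python
import torch
from types import SimpleNamespace

B, H, S, D = 2, 1, 5, 4
layer = SimpleNamespace(keys=torch.zeros(B, H, S, D), values=torch.zeros(B, H, S, D))
cache = SimpleNamespace(layers=[layer])
old_ids = torch.tensor([[0, 0, 5, 6, 7], [0, 9, 9, 9, 9]])  # left-padded with pad_token_id=0
new_ids = torch.tensor([[0, 0, 5, 6, 8, 3], [0, 9, 9, 9, 9, 3]])
cache = align_cache(cache, new_ids, old_ids, pad_token_id=0)
```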