Commit 7a6001c

fixup
1 parent d0e25fe commit 7a6001c

File tree: 4 files changed, +181 / -106 lines changed


src/transformers/generation/candidate_generator.py

Lines changed: 84 additions & 41 deletions
@@ -58,7 +58,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
             f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
         )

-    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, assistant_used: bool = True):
+    def update_candidate_strategy(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, assistant_used: bool = True
+    ):
         """
         Updates the candidate generation strategy based on the outcomes.

@@ -199,7 +201,9 @@ def __init__(
         self.matches = []
         self.clean_probs = []

-    def get_candidates(self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
+    def get_candidates(
+        self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None
+    ) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
         """
         Fetches the candidates to be tried for the current input.

@@ -224,7 +228,9 @@ def get_candidates(self, input_ids: torch.LongTensor, assistant_ids_in_cache: to
         candidate_ids, candidate_logits = self._generate_candidates(generation_args)
         return candidate_ids, candidate_logits

-    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, assistant_used: bool = True):
+    def update_candidate_strategy(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, assistant_used: bool = True
+    ):
         """
         Updates the candidate generation strategy based on the outcomes.

@@ -239,21 +245,20 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
                 If `int`, assumes `batch_size=1` for backward compatibility.
             assistant_used (`bool`):
                 Whether the assistant was used to generate the candidates. Assistant was not used if max_new_tokens is 0.
-        """
+        """
         # Handle backward compatibility: convert int to tensor
         if isinstance(num_matches, int):
             num_matches = torch.tensor([num_matches], device=input_ids.device)
-
+
         batch_size = input_ids.shape[0]
-
+
         # Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
         # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
         # cost of forecasting incorrect assistant tokens.
         if self.assistant_model.generation_config.num_assistant_tokens_schedule in {
             "heuristic",
             "heuristic_transient",
         }:
-
             # For batch processing, we can use different strategies:
             # Option 1: Use average matches across batch
             avg_matches = num_matches.float().mean().item()
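
Note on the hunk above: with the "heuristic" / "heuristic_transient" schedules, the draft length requested from the assistant grows when the whole previous draft was accepted and shrinks otherwise (the `max(1.0, ... - 1.0)` floor appears in the next hunk). A minimal sketch of that update rule, reduced to the averaged match count the diff computes; the +2.0 / -1.0 step sizes mirror the usual transformers heuristic and are shown only for illustration, not copied from this commit:

def adjust_num_assistant_tokens(num_assistant_tokens: float, avg_matches: float) -> float:
    # Sketch of the "heuristic" schedule: grow the draft when everything was accepted,
    # otherwise back off, never going below one drafted token.
    if avg_matches == int(num_assistant_tokens):
        return num_assistant_tokens + 2.0
    return max(1.0, num_assistant_tokens - 1.0)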
@@ -268,7 +273,7 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
                 self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)

         # The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes.
-        # The costs are estimated based on the ROC curve, which considers the probability of the draft token and its match with the target.
+        # The costs are estimated based on the ROC curve, which considers the probability of the draft token and its match with the target.
         # A cost of 25% is assigned to false positives and 75% to false negatives.
         # This adaptation is not compatible with UAG, as it relies on the number of matched tokens based on the draft vocabulary, which is unavailable in UAG.
         if (
@@ -287,13 +292,13 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
                     # this means we reject a token.
                     item_matches.append(0)
                 # taking only the relevant probabilities. for all the accepted tokens and the first rejected token.
-                self.clean_probs.extend([self.probs[len(self.matches)][:len(item_matches)]])
+                self.clean_probs.extend([self.probs[len(self.matches)][: len(item_matches)]])
                 self.matches.extend([item_matches])
-
+
             assert len(self.matches) == len(self.clean_probs), "matches and probs must have the same length"
             clean_matches = np.concatenate(self.matches)
             clean_probs = np.concatenate(self.clean_probs)
-
+
             # calculate ROC curve and update threshold if we have enough samples
             if (
                 len(clean_probs) > 5 and {0, 1}.issubset(clean_matches)
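
Note on the threshold update above: the diff only shows the comments and the guard (more than 5 samples, both accepted and rejected tokens present), not the full computation. A hypothetical sketch of how such a threshold could be derived from the ROC curve with the 25% false-positive / 75% false-negative costs mentioned in the comments; the function name and the exact cost combination are assumptions, not the library's code:

import numpy as np
from sklearn.metrics import roc_curve

def pick_confidence_threshold(clean_probs: np.ndarray, clean_matches: np.ndarray) -> float:
    # clean_matches: 1 where the target accepted the drafted token, 0 where it rejected it.
    fpr, tpr, thresholds = roc_curve(clean_matches, clean_probs)
    # Weight false positives at 25% and false negatives at 75%, then take the cheapest threshold.
    cost = 0.25 * fpr + 0.75 * (1.0 - tpr)
    return float(thresholds[np.argmin(cost)])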
@@ -318,20 +323,29 @@ def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> tuple[int, int]:
         return min_new_tokens, max_new_tokens

     def _update_past_and_masks(
-        self, input_ids: torch.LongTensor, remove_from_pkv: int = 0, num_added_tokens: int = 1, assistant_ids_in_cache: torch.LongTensor = None
+        self,
+        input_ids: torch.LongTensor,
+        remove_from_pkv: int = 0,
+        num_added_tokens: int = 1,
+        assistant_ids_in_cache: torch.LongTensor = None,
     ) -> bool:
         """Update past key values and attention masks for subsequent generation rounds."""
         has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
         if has_past_key_values:
             new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
             current_cache = self.assistant_kwargs["past_key_values"]
-            if (batch_size:=input_ids.shape[0]) > 1:
-                self.assistant_kwargs["past_key_values"] = align_cache(current_cache, input_ids, assistant_ids_in_cache,
-                                                                       self.generation_config.pad_token_id)
+            if input_ids.shape[0] > 1:
+                self.assistant_kwargs["past_key_values"] = align_cache(
+                    current_cache, input_ids, assistant_ids_in_cache, self.generation_config.pad_token_id
+                )
             else:
                 self.assistant_kwargs["past_key_values"].crop(new_cache_size - num_added_tokens)
             self.assistant_kwargs = _prepare_attention_mask(
-                self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder, input_ids, self.generation_config.pad_token_id
+                self.assistant_kwargs,
+                input_ids.shape[-1],
+                self.assistant_model.config.is_encoder_decoder,
+                input_ids,
+                self.generation_config.pad_token_id,
             )
             self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])

@@ -355,15 +369,23 @@ def _generate_candidates(self, generation_args: dict) -> tuple[torch.LongTensor,
         """Generate candidate sequences using the assistant model."""
         assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
         self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
-        candidate_logits = torch.stack(assistant_output.scores, dim=1)  # shape: (batch_size, candidate_length, vocab_size)
+        candidate_logits = torch.stack(
+            assistant_output.scores, dim=1
+        )  # shape: (batch_size, candidate_length, vocab_size)
         if (
             is_sklearn_available()
             and self.assistant_model.generation_config.assistant_confidence_threshold
             and type(self) is AssistedCandidateGenerator
         ):
-            scores_softmax = torch.softmax(candidate_logits, dim=-1)  # shape: (batch_size, candidate_length, vocab_size)
-            ids = assistant_output.sequences[:, -len(assistant_output.scores):]  # shape: (batch_size, candidate_length)
-            p = torch.gather(scores_softmax, dim=-1, index=ids.unsqueeze(-1)).squeeze(-1)  # shape: (batch_size, candidate_length)
+            scores_softmax = torch.softmax(
+                candidate_logits, dim=-1
+            )  # shape: (batch_size, candidate_length, vocab_size)
+            ids = assistant_output.sequences[
+                :, -len(assistant_output.scores) :
+            ]  # shape: (batch_size, candidate_length)
+            p = torch.gather(scores_softmax, dim=-1, index=ids.unsqueeze(-1)).squeeze(
+                -1
+            )  # shape: (batch_size, candidate_length)
             self.probs.extend(p.tolist())
         candidate_ids = assistant_output.sequences
         return candidate_ids, candidate_logits
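
The reformatted `torch.gather` above extracts, for every drafted token, the probability the assistant assigned to it. A standalone toy example with made-up shapes showing the same gather pattern:

import torch

logits = torch.randn(2, 5, 32000)       # (batch_size, candidate_length, vocab_size)
ids = torch.randint(0, 32000, (2, 5))   # drafted token ids, (batch_size, candidate_length)
probs = torch.softmax(logits, dim=-1)
p = torch.gather(probs, dim=-1, index=ids.unsqueeze(-1)).squeeze(-1)  # (batch_size, candidate_length)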
@@ -530,7 +552,9 @@ def convert_source_tokens_to_target_tokens(
         dest_ids = destination_tokenizer(text, add_special_tokens=True, return_tensors="pt")["input_ids"]
         return dest_ids.to(input_ids.device)

-    def get_candidates(self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
+    def get_candidates(
+        self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None
+    ) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
         """
         Fetches the candidates to be tried for the current input.

@@ -555,7 +579,9 @@ def get_candidates(self, input_ids: torch.LongTensor, assistant_ids_in_cache: to

         min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - assistant_input_ids.shape[-1]), 0)

-        self._update_past_and_masks(assistant_input_ids, remove_from_pkv, assistant_ids_in_cache=assistant_ids_in_cache)
+        self._update_past_and_masks(
+            assistant_input_ids, remove_from_pkv, assistant_ids_in_cache=assistant_ids_in_cache
+        )
         generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
         self.assistant_kwargs.pop("attention_mask", None)

@@ -955,7 +981,9 @@ def __init__(
         self._target_seq_len_with_candidates: int = 0
         self._prev_assistant_ids: torch.LongTensor | None = None

-    def get_candidates(self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
+    def get_candidates(
+        self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None
+    ) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
         """
         Simplified version of get_candidates that uses the translator cache for token conversion.
         """
@@ -966,7 +994,9 @@ def get_candidates(self, input_ids: torch.LongTensor, assistant_ids_in_cache: to
         if max_new_tokens == 0:
             return input_ids, None

-        self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens, assistant_ids_in_cache=assistant_ids_in_cache)
+        self._update_past_and_masks(
+            assistant_input_ids, num_added_tokens=num_added_tokens, assistant_ids_in_cache=assistant_ids_in_cache
+        )
         generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)

         # Ensure scores are returned
@@ -987,7 +1017,12 @@ def get_candidates(self, input_ids: torch.LongTensor, assistant_ids_in_cache: to

         return target_candidate_ids, target_candidate_logits

-    def _update_past_and_masks(self, assistant_input_ids: torch.LongTensor, num_added_tokens: int = 1, assistant_ids_in_cache: torch.LongTensor = None) -> bool:
+    def _update_past_and_masks(
+        self,
+        assistant_input_ids: torch.LongTensor,
+        num_added_tokens: int = 1,
+        assistant_ids_in_cache: torch.LongTensor = None,
+    ) -> bool:
         if self._prev_assistant_ids is None:
             # Prepare attention mask for the first generation.
             # For subsequent generations, the attention mask is updated in super()_update_past_and_masks.
@@ -1175,7 +1210,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
         # assisted_generation expects logits as well, but we don't have those here, so returning None
         return candidate_input_ids, None

-    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, assistant_used: bool = True):
+    def update_candidate_strategy(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, assistant_used: bool = True
+    ):
         """
         Updates the candidate generation strategy based on the outcomes.

@@ -1250,7 +1287,13 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
         return candidate_ids, candidate_logits


-def _prepare_attention_mask(model_kwargs: dict[str, Any], new_length: int, is_encoder_decoder: bool, input_ids: torch.LongTensor = None, pad_token_id: int = None) -> dict[str, Any]:
+def _prepare_attention_mask(
+    model_kwargs: dict[str, Any],
+    new_length: int,
+    is_encoder_decoder: bool,
+    input_ids: torch.LongTensor | None = None,
+    pad_token_id: int | None = None,
+) -> dict[str, Any]:
     """Expands or crops the model's mask for decoding purposes, to the defined length"""

     mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
@@ -1261,7 +1304,7 @@ def _prepare_attention_mask(model_kwargs: dict[str, Any], new_length: int, is_en
     mask_length_diff = new_length - mask.shape[1]
     if input_ids is not None and pad_token_id is not None:
         model_kwargs[mask_key] = (input_ids != pad_token_id).to(mask.dtype)
-    elif mask_length_diff < 0: # not sure when we get into this case
+    elif mask_length_diff < 0:  # not sure when we get into this case
         model_kwargs[mask_key] = mask[:, :mask_length_diff]
     elif mask_length_diff > 0:
         model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
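
The new `input_ids` / `pad_token_id` branch above rebuilds the attention mask directly from the padded ids instead of growing or cropping the previous mask, which is what makes left-padded batches work. A small illustration (the pad id 0 is an arbitrary choice for the example):

import torch

pad_token_id = 0
input_ids = torch.tensor([[0, 0, 5, 6, 7],
                          [3, 4, 5, 6, 7]])
attention_mask = (input_ids != pad_token_id).long()
# -> tensor([[0, 0, 1, 1, 1],
#            [1, 1, 1, 1, 1]])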
@@ -1307,38 +1350,38 @@ def _prepare_token_type_ids(model_kwargs: dict[str, Any], new_length: int) -> di
 def align_cache(cache, new_ids, assistant_ids_in_cache, pad_token_id, apply_same_transform_on_old_ids: bool = False):
     # 1. Setup metadata (Shape: [Batch, Heads, Sequence_Length, Dimension])
     # We access the first layer just to get shapes and device
-
+
     ref_layer = cache.layers[0]
     old_ids = assistant_ids_in_cache
     B, H, S_old, D = ref_layer.keys.shape
-    S_new = new_ids.shape[1] - 1 # Preserving your original sizing logic
-
+    S_new = new_ids.shape[1] - 1  # Preserving your original sizing logic
+
     # 2. Pre-calculate "What to copy" for the whole batch ONCE.
     # This removes the logic calculation from the inner loop (32x-80x speedup on logic)
-
+
     # Find start indices (Vectorized)
-    # Note: sum() assumes left-padding only.
+    # Note: sum() assumes left-padding only.
     old_start_indices = (old_ids == pad_token_id).sum(dim=1)
     new_start_indices = (new_ids == pad_token_id).sum(dim=1)
-
+
     # We will store the copy instructions here to apply to all layers later
     # Format: List of tuples (batch_idx, source_slice, dest_slice, copy_len)
     copy_instructions = []
-
+
     # We still loop over batch (B), but only once, not B * Layers
     for i in range(B):
         # Identify the content without padding
         # We use standard python slicing here as it's just index math, very fast
         o_start = old_start_indices[i].item()
         n_start = new_start_indices[i].item()
-
+
         # Get the actual token sequences (views, not copies)
         # We perform the comparison on the ID tensors (int64), which is cheap
         trimmed_old = old_ids[i, o_start:]
         trimmed_new = new_ids[i, n_start:]
-
+
         min_len = min(len(trimmed_old), len(trimmed_new))
-
+
         # Compare only up to min_len
         # Using .ne() (not equal) and finding the first true is faster than checks
         if min_len == 0:
@@ -1356,7 +1399,7 @@ def align_cache(cache, new_ids, assistant_ids_in_cache, pad_token_id, apply_same
         # Define the slice objects now so we don't recreate them 32 times
         src_slice = slice(o_start, o_start + copy_len)
         # You align to the right (-length:)
-        dst_slice = slice(-copy_len, None)
+        dst_slice = slice(-copy_len, None)
         copy_instructions.append((i, src_slice, dst_slice))

     # 3. Apply changes to all layers
@@ -1373,7 +1416,7 @@ def align_cache(cache, new_ids, assistant_ids_in_cache, pad_token_id, apply_same
             # Copy Keys/Values
             new_keys[i, :, dst_slice] = layer.keys[i, :, src_slice]
             new_values[i, :, dst_slice] = layer.values[i, :, src_slice]
-
+
             if apply_same_transform_on_old_ids and new_input_ids_in_cache is not None:
                 new_input_ids_in_cache[i, dst_slice] = assistant_ids_in_cache[i, src_slice]
         # Update the layer
@@ -1382,4 +1425,4 @@ def align_cache(cache, new_ids, assistant_ids_in_cache, pad_token_id, apply_same

     if apply_same_transform_on_old_ids:
         return cache, new_input_ids_in_cache
-    return cache
+    return cache

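To summarise what `align_cache` does per batch item in the hunks above: it strips left padding from the cached ids and the new ids, measures how long their common prefix is, and copies that many key/value positions so they end up right-aligned in the new padded layout. A simplified, cache-free sketch of just the index math (the function name and return convention are illustrative, not part of the diff):

import torch

def matching_prefix_slices(old_ids_row: torch.Tensor, new_ids_row: torch.Tensor, pad_token_id: int):
    # Returns (src_slice, dst_slice) for the cache positions that survive, assuming left padding.
    o_start = int((old_ids_row == pad_token_id).sum())
    n_start = int((new_ids_row == pad_token_id).sum())
    trimmed_old = old_ids_row[o_start:]
    trimmed_new = new_ids_row[n_start:]
    min_len = min(len(trimmed_old), len(trimmed_new))
    if min_len == 0:
        return None
    mismatches = (trimmed_old[:min_len] != trimmed_new[:min_len]).nonzero()
    copy_len = int(mismatches[0]) if len(mismatches) > 0 else min_len
    if copy_len == 0:
        return None
    # Copy from the unpadded region of the old cache, right-aligned into the new layout.
    return slice(o_start, o_start + copy_len), slice(-copy_len, None)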