|
41 | 41 | class CandidateGenerator: |
42 | 42 | """Abstract base class for all candidate generators that can be applied during assisted generation.""" |
43 | 43 |
|
44 | | - def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, torch.FloatTensor | None]: |
| 44 | + def get_candidates( |
| 45 | + self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor | None = None |
| 46 | + ) -> tuple[torch.LongTensor, torch.FloatTensor | None]: |
45 | 47 | """ |
46 | 48 | Fetches the candidates to be tried for the current input. |
47 | 49 |
|
48 | 50 | Args: |
49 | 51 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
50 | 52 | Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) |
| 53 | + assistant_ids_in_cache (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| 54 | + Indices of input sequence tokens in the assistant vocabulary that are already in the assistant cache. |
51 | 55 |

52 | 56 | Return: |
53 | 57 | `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be |
@@ -248,6 +251,8 @@ def update_candidate_strategy( |
248 | 251 | """ |
249 | 252 | # Handle backward compatibility: convert int to tensor |
250 | 253 | if isinstance(num_matches, int): |
| 254 | + if input_ids.shape[0] != 1: |
| 255 | + raise ValueError("`num_matches` must be a tensor of shape `(batch_size,)` when `batch_size > 1`.") |
251 | 256 | num_matches = torch.tensor([num_matches], device=input_ids.device) |
252 | 257 |

253 | 258 | batch_size = input_ids.shape[0] |
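
The shim above keeps existing single-sequence callers working: a plain `int` is still accepted when `batch_size == 1`, while batched callers pass one match count per row. A hedged usage sketch (`candidate_generator`, `input_ids`, and `scores` are placeholders; the three-argument form mirrors the existing `update_candidate_strategy(input_ids, scores, num_matches)` interface):

```python
import torch

# batch_size == 1: a plain int still works and is converted to a 1-element tensor
candidate_generator.update_candidate_strategy(input_ids, scores, num_matches=3)

# batch_size > 1: pass one match count per row; a plain int now raises
per_row_matches = torch.tensor([3, 1], device=input_ids.device)
candidate_generator.update_candidate_strategy(input_ids, scores, per_row_matches)
```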
@@ -332,13 +336,12 @@ def _update_past_and_masks( |
332 | 336 | """Update past key values and attention masks for subsequent generation rounds.""" |
333 | 337 | has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None |
334 | 338 | if has_past_key_values: |
335 | | - new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv |
336 | | - current_cache = self.assistant_kwargs["past_key_values"] |
337 | 339 | if input_ids.shape[0] > 1: |
338 | | - self.assistant_kwargs["past_key_values"] = align_cache( |
339 | | - current_cache, input_ids, assistant_ids_in_cache, self.generation_config.pad_token_id |
| 340 | + self.assistant_kwargs["past_key_values"].align( |
| 341 | + input_ids, assistant_ids_in_cache, self.generation_config.pad_token_id |
340 | 342 | ) |
341 | 343 | else: |
| 344 | + new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv |
342 | 345 | self.assistant_kwargs["past_key_values"].crop(new_cache_size - num_added_tokens) |
343 | 346 | self.assistant_kwargs = _prepare_attention_mask( |
344 | 347 | self.assistant_kwargs, |
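
For the batched branch above, the semantics (inherited from the removed `align_cache` helper at the bottom of this diff) are: strip left padding from both the ids already backing the assistant cache and the new ids, keep the longest shared prefix, and copy it right-aligned into fresh key/value tensors of sequence length `new_ids.shape[1] - 1`. A standalone sketch of that index math on a toy pair of sequences (the `align` method itself is assumed to be added on the cache class elsewhere in this PR):

```python
import torch

pad = 0
old_ids = torch.tensor([[pad, pad, 5, 6, 7]])  # ids currently backing the cache
new_ids = torch.tensor([[pad, 5, 6, 9]])       # ids for the next assistant round

# Left-padding offsets, computed exactly as in the removed helper.
o_start = (old_ids[0] == pad).sum().item()  # 2
n_start = (new_ids[0] == pad).sum().item()  # 1

trimmed_old = old_ids[0, o_start:]  # tensor([5, 6, 7])
trimmed_new = new_ids[0, n_start:]  # tensor([5, 6, 9])

# First mismatch gives the length of the reusable prefix.
min_len = min(len(trimmed_old), len(trimmed_new))
mismatch = trimmed_old[:min_len].ne(trimmed_new[:min_len])
copy_len = mismatch.int().argmax().item() if mismatch.any() else min_len  # 2

# Keys/values for those 2 positions are copied right-aligned into new
# tensors of sequence length new_ids.shape[1] - 1 == 3; the rest is zeros.
```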
@@ -561,7 +564,9 @@ def get_candidates( |
561 | 564 | Args: |
562 | 565 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
563 | 566 | Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) |
564 | | -
|
| 567 | + assistant_ids_in_cache (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| 568 | + Indices of input sequence tokens in the assistant vocabulary that are already in the assistant cache. |
| 569 | + |
565 | 570 | Return: |
566 | 571 | `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be |
567 | 572 | assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length, |
@@ -1116,13 +1120,17 @@ def __init__( |
1116 | 1120 | if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0: |
1117 | 1121 | raise ValueError("Invalid max_matching_ngram_size or num_output_tokens") |
1118 | 1122 |
|
1119 | | - def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, torch.FloatTensor | None]: |
| 1123 | + def get_candidates( |
| 1124 | + self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor | None = None |
| 1125 | + ) -> tuple[torch.LongTensor, torch.FloatTensor | None]: |
1120 | 1126 | """ |
1121 | 1127 | Fetches the candidates to be tried for the current input. |
1122 | 1128 |
|
1123 | 1129 | Args: |
1124 | 1130 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
1125 | 1131 | Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) |
| 1132 | + assistant_ids_in_cache (`torch.LongTensor`, *optional*): |
| 1133 | + Assistant model input IDs that are already in the cache. Not used by prompt lookup decoding. |
1126 | 1134 |
|
1127 | 1135 | Return: |
1128 | 1136 | `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried. |
@@ -1277,12 +1285,16 @@ def __init__( |
1277 | 1285 | self.assistant_early_exit = self.generation_config.assistant_early_exit |
1278 | 1286 | self.generation_config.assistant_early_exit = None |
1279 | 1287 |
|
1280 | | - def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, torch.FloatTensor | None]: |
| 1288 | + def get_candidates( |
| 1289 | + self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor | None = None |
| 1290 | + ) -> tuple[torch.LongTensor, torch.FloatTensor | None]: |
1281 | 1291 | # Temporarily sets the number of hidden layers to the early exit value |
1282 | 1292 | base_model = getattr(self.assistant_model, self.assistant_model.base_model_prefix) |
1283 | 1293 | original_num_hidden_layers = base_model.config.num_hidden_layers |
1284 | 1294 | base_model.config.num_hidden_layers = self.assistant_early_exit |
1285 | | - candidate_ids, candidate_logits = super().get_candidates(input_ids) |
| 1295 | + candidate_ids, candidate_logits = super().get_candidates( |
| 1296 | + input_ids, assistant_ids_in_cache=assistant_ids_in_cache |
| 1297 | + ) |
1286 | 1298 | base_model.config.num_hidden_layers = original_num_hidden_layers |
1287 | 1299 | return candidate_ids, candidate_logits |
1288 | 1300 |
|
@@ -1345,84 +1357,3 @@ def _prepare_token_type_ids(model_kwargs: dict[str, Any], new_length: int) -> di |
1345 | 1357 | token_type_copies = final_token_type.repeat(1, type_length_diff) |
1346 | 1358 | model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1) |
1347 | 1359 | return model_kwargs |
1348 | | - |
1349 | | - |
1350 | | -def align_cache(cache, new_ids, assistant_ids_in_cache, pad_token_id, apply_same_transform_on_old_ids: bool = False): |
1351 | | - # 1. Setup metadata (Shape: [Batch, Heads, Sequence_Length, Dimension]) |
1352 | | - # We access the first layer just to get shapes and device |
1353 | | - |
1354 | | - ref_layer = cache.layers[0] |
1355 | | - old_ids = assistant_ids_in_cache |
1356 | | - B, H, S_old, D = ref_layer.keys.shape |
1357 | | - S_new = new_ids.shape[1] - 1 # Preserving your original sizing logic |
1358 | | - |
1359 | | - # 2. Pre-calculate "What to copy" for the whole batch ONCE. |
1360 | | - # This removes the logic calculation from the inner loop (32x-80x speedup on logic) |
1361 | | - |
1362 | | - # Find start indices (Vectorized) |
1363 | | - # Note: sum() assumes left-padding only. |
1364 | | - old_start_indices = (old_ids == pad_token_id).sum(dim=1) |
1365 | | - new_start_indices = (new_ids == pad_token_id).sum(dim=1) |
1366 | | - |
1367 | | - # We will store the copy instructions here to apply to all layers later |
1368 | | - # Format: List of tuples (batch_idx, source_slice, dest_slice, copy_len) |
1369 | | - copy_instructions = [] |
1370 | | - |
1371 | | - # We still loop over batch (B), but only once, not B * Layers |
1372 | | - for i in range(B): |
1373 | | - # Identify the content without padding |
1374 | | - # We use standard python slicing here as it's just index math, very fast |
1375 | | - o_start = old_start_indices[i].item() |
1376 | | - n_start = new_start_indices[i].item() |
1377 | | - |
1378 | | - # Get the actual token sequences (views, not copies) |
1379 | | - # We perform the comparison on the ID tensors (int64), which is cheap |
1380 | | - trimmed_old = old_ids[i, o_start:] |
1381 | | - trimmed_new = new_ids[i, n_start:] |
1382 | | - |
1383 | | - min_len = min(len(trimmed_old), len(trimmed_new)) |
1384 | | - |
1385 | | - # Compare only up to min_len |
1386 | | - # Using .ne() (not equal) and finding the first true is faster than checks |
1387 | | - if min_len == 0: |
1388 | | - copy_len = 0 |
1389 | | - else: |
1390 | | - # Find mismatch: (a != b) |
1391 | | - mismatch = trimmed_old[:min_len].ne(trimmed_new[:min_len]) |
1392 | | - if not mismatch.any(): |
1393 | | - copy_len = min_len |
1394 | | - else: |
1395 | | - # argmax on boolean gives index of first True |
1396 | | - copy_len = mismatch.int().argmax().item() |
1397 | | - |
1398 | | - if copy_len > 0: |
1399 | | - # Define the slice objects now so we don't recreate them 32 times |
1400 | | - src_slice = slice(o_start, o_start + copy_len) |
1401 | | - # You align to the right (-length:) |
1402 | | - dst_slice = slice(-copy_len, None) |
1403 | | - copy_instructions.append((i, src_slice, dst_slice)) |
1404 | | - |
1405 | | - # 3. Apply changes to all layers |
1406 | | - # We allocate new tensors and copy in bulk based on pre-calculated instructions |
1407 | | - new_input_ids_in_cache = None |
1408 | | - for layer in cache.layers: |
1409 | | - # Allocation (This is the heavy GPU/Memory op) |
1410 | | - new_keys = layer.keys.new_zeros((B, H, S_new, D)) |
1411 | | - new_values = layer.values.new_zeros((B, H, S_new, D)) |
1412 | | - if apply_same_transform_on_old_ids and new_input_ids_in_cache is None: |
1413 | | - new_input_ids_in_cache = assistant_ids_in_cache.new_zeros((B, S_new)) |
1414 | | - # Execute the pre-calculated copy instructions |
1415 | | - for i, src_slice, dst_slice in copy_instructions: |
1416 | | - # Copy Keys/Values |
1417 | | - new_keys[i, :, dst_slice] = layer.keys[i, :, src_slice] |
1418 | | - new_values[i, :, dst_slice] = layer.values[i, :, src_slice] |
1419 | | - |
1420 | | - if apply_same_transform_on_old_ids and new_input_ids_in_cache is not None: |
1421 | | - new_input_ids_in_cache[i, dst_slice] = assistant_ids_in_cache[i, src_slice] |
1422 | | - # Update the layer |
1423 | | - layer.keys = new_keys |
1424 | | - layer.values = new_values |
1425 | | - |
1426 | | - if apply_same_transform_on_old_ids: |
1427 | | - return cache, new_input_ids_in_cache |
1428 | | - return cache |
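
Since the new keyword is threaded uniformly through every `get_candidates` override, a caller only adds one argument. A hypothetical call-site sketch (the actual caller in `generation/utils.py` is not part of this diff; names are placeholders):

```python
# hypothetical call site: pass the ids already backing the assistant cache so
# batched generators can re-align their past key/values before drafting
candidate_ids, candidate_logits = candidate_generator.get_candidates(
    input_ids, assistant_ids_in_cache=assistant_ids_in_cache
)
```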