
Commit a48d68c

Fix some models cache initialization (#42586)

* Create cache when training, in case generate needs to be called
* Align modular
* fixes
* cohere
* fix modular
* fix
* review

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
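Across the files shown below, the behavioral fix is the cache-initialization guard: a DynamicCache is now created whenever `use_cache` is true and no `past_key_values` were passed, even while the model is in training mode, so that `generate()` can be called from inside a training loop (e.g. RL-style fine-tuning). A minimal sketch of that scenario; the checkpoint id is a placeholder, not taken from this commit:

```python
# Minimal sketch of the scenario this commit addresses: calling generate()
# while the model is still in train() mode. The checkpoint id is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-org/your-cohere2-or-gemma3-checkpoint"  # placeholder, not from this commit
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

model.train()  # stays in training mode, as in an RL-style fine-tuning loop
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Before this change, the decoder skipped creating a DynamicCache when
# self.training was True; now past_key_values is initialized whenever
# use_cache=True and no cache object was passed in.
out = model.generate(**inputs, max_new_tokens=8, use_cache=True)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```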
1 parent 9e82c77 commit a48d68c

File tree

10 files changed: +70, -448 lines changed

src/transformers/models/cohere2/modeling_cohere2.py

Lines changed: 3 additions & 4 deletions
@@ -29,7 +29,6 @@
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
@@ -233,7 +232,7 @@ def forward(
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -304,7 +303,7 @@ def forward(
         past_key_values: Optional[Cache] = None,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -398,7 +397,7 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if use_cache and past_key_values is None and not self.training:
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)

         if cache_position is None:
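Besides the cache guard, the hunks above (and the Gemma3 file below) retype `**kwargs` from `FlashAttentionKwargs` to the broader `TransformersKwargs`. A small, generic sketch of the `Unpack[TypedDict]` kwargs-typing pattern this relies on; the `ExampleKwargs` TypedDict is illustrative only and is not the library's actual `TransformersKwargs` definition:

```python
# Generic sketch of the Unpack[TypedDict] kwargs-typing pattern used in the diff.
# ExampleKwargs is illustrative; it is not the real TransformersKwargs definition.
from typing_extensions import TypedDict, Unpack


class ExampleKwargs(TypedDict, total=False):
    output_attentions: bool
    output_hidden_states: bool


def forward(hidden_states, **kwargs: Unpack[ExampleKwargs]):
    # Type checkers see exactly which keyword arguments may be forwarded here,
    # while the runtime behaviour of **kwargs is unchanged.
    if kwargs.get("output_attentions"):
        print("caller asked for attention weights")
    return hidden_states


forward(object(), output_attentions=True)
```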

src/transformers/models/cohere2/modular_cohere2.py

Lines changed: 3 additions & 4 deletions
@@ -22,7 +22,6 @@
 from ...cache_utils import Cache, DynamicCache
 from ...configuration_utils import PreTrainedConfig, layer_type_validation
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutputWithPast
 from ...modeling_rope_utils import (
     RopeParameters,
@@ -271,7 +270,7 @@ def forward(
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -322,7 +321,7 @@ def forward(
         past_key_values: Optional[Cache] = None,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -367,7 +366,7 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if use_cache and past_key_values is None and not self.training:
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)

         if cache_position is None:

src/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 10 additions & 80 deletions
@@ -33,7 +33,6 @@
 from ...generation import GenerationMixin
 from ...integrations import use_kernel_func_from_hub
 from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
@@ -347,7 +346,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -409,23 +408,19 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states

         hidden_states = self.input_layernorm(hidden_states)

-        hidden_states, self_attn_weights = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             position_embeddings=position_embeddings,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
@@ -438,12 +433,7 @@ def forward(
         hidden_states = self.post_feedforward_layernorm(hidden_states)
         hidden_states = residual + hidden_states

-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
+        return hidden_states


 @auto_docstring
@@ -527,30 +517,16 @@ def forward(
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if use_cache and past_key_values is None and not self.training:
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)

         if cache_position is None:
@@ -591,41 +567,22 @@ def forward(
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                 position_embeddings=position_embeddings[decoder_layer.attention_type],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
                 cache_position=cache_position,
                 **kwargs,
             )

-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)

-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
         )


@@ -918,10 +875,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **lm_kwargs,
+        **lm_kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, Gemma3ModelOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -953,12 +907,6 @@ def forward(
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         # Replace image id with PAD if the image token if OOV, to avoid index-errors
         if input_ids is not None and self.config.image_token_id >= self.vocab_size:
             special_image_mask = input_ids == self.config.image_token_id
@@ -1005,16 +953,14 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             return_dict=True,
             cache_position=cache_position,
             **lm_kwargs,
         )

         return Gemma3ModelOutputWithPast(
             last_hidden_state=outputs.last_hidden_state,
-            past_key_values=outputs.past_key_values if use_cache else None,
+            past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
             image_hidden_states=image_features if pixel_values is not None else None,
@@ -1053,6 +999,7 @@ def set_input_embeddings(self, value):
     def get_image_features(self, pixel_values):
         return self.model.get_image_features(pixel_values)

+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -1066,11 +1013,8 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **lm_kwargs,
+        **lm_kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1116,13 +1060,6 @@ def forward(
         "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
         ```
         """
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         outputs = self.model(
             input_ids=input_ids,
             pixel_values=pixel_values,
@@ -1133,9 +1070,6 @@ def forward(
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             labels=labels,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
             cache_position=cache_position,
             **lm_kwargs,
         )
@@ -1167,10 +1101,6 @@ def forward(
         flat_labels = shift_labels.view(-1).to(shift_logits.device)
         loss = loss_fct(flat_logits, flat_labels)

-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
         return Gemma3CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
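The Gemma3 decoder layer also drops its conditional tuple return: the layer now hands back the hidden-states tensor directly, and flags such as `output_attentions` travel through `**kwargs` instead of being threaded by hand. A self-contained toy illustrating the new calling pattern (the `Toy*` classes are stand-ins, not Gemma3 code):

```python
# Toy illustration of the simplified decoder-layer return value in this diff.
# ToyAttention/ToyLayer are stand-ins, not the actual Gemma3 modules.
import torch
from torch import nn


class ToyAttention(nn.Module):
    def forward(self, hidden_states, **kwargs):
        # Mirrors the library's attention modules: returns (hidden_states, attn_weights).
        return hidden_states, None


class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = ToyAttention()

    def forward(self, hidden_states, **kwargs):
        # New pattern: ignore the attention weights and return the tensor itself,
        # instead of building a tuple that the caller has to unpack.
        hidden_states, _ = self.self_attn(hidden_states, **kwargs)
        return hidden_states


layer = ToyLayer()
x = torch.randn(1, 4, 8)
x = layer(x, output_attentions=True)  # caller gets the tensor directly
print(x.shape)
```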
