 from ...generation import GenerationMixin
 from ...integrations import use_kernel_func_from_hub
 from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
@@ -347,7 +346,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -409,23 +408,19 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states

         hidden_states = self.input_layernorm(hidden_states)

-        hidden_states, self_attn_weights = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             position_embeddings=position_embeddings,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
@@ -438,12 +433,7 @@ def forward(
         hidden_states = self.post_feedforward_layernorm(hidden_states)
         hidden_states = residual + hidden_states

-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
+        return hidden_states


 @auto_docstring
@@ -527,30 +517,16 @@ def forward(
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if use_cache and past_key_values is None and not self.training:
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)

         if cache_position is None:
@@ -591,41 +567,22 @@ def forward(
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                 position_embeddings=position_embeddings[decoder_layer.attention_type],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
                 cache_position=cache_position,
                 **kwargs,
             )

-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)

-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
         )
@@ -918,10 +875,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **lm_kwargs,
+        **lm_kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, Gemma3ModelOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -953,12 +907,6 @@ def forward(
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         # Replace image id with PAD if the image token is OOV, to avoid index-errors
         if input_ids is not None and self.config.image_token_id >= self.vocab_size:
             special_image_mask = input_ids == self.config.image_token_id
@@ -1005,16 +953,14 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             return_dict=True,
             cache_position=cache_position,
             **lm_kwargs,
         )

         return Gemma3ModelOutputWithPast(
             last_hidden_state=outputs.last_hidden_state,
-            past_key_values=outputs.past_key_values if use_cache else None,
+            past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
             image_hidden_states=image_features if pixel_values is not None else None,
@@ -1053,6 +999,7 @@ def set_input_embeddings(self, value):
     def get_image_features(self, pixel_values):
         return self.model.get_image_features(pixel_values)

+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -1066,11 +1013,8 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **lm_kwargs,
+        **lm_kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1116,13 +1060,6 @@ def forward(
         "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
         ```
         """
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         outputs = self.model(
             input_ids=input_ids,
             pixel_values=pixel_values,
@@ -1133,9 +1070,6 @@ def forward(
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             labels=labels,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
             cache_position=cache_position,
             **lm_kwargs,
         )
@@ -1167,10 +1101,6 @@ def forward(
             flat_labels = shift_labels.view(-1).to(shift_logits.device)
             loss = loss_fct(flat_logits, flat_labels)

-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
         return Gemma3CausalLMOutputWithPast(
             loss=loss,
             logits=logits,