From a554b7f8f657916e35ea39998cfd3a52eb193f07 Mon Sep 17 00:00:00 2001
From: default
Date: Thu, 4 Dec 2025 17:16:59 +0000
Subject: [PATCH 1/4] Fix GraniteMoeHybridModel._update_mamba_mask for
 torch.export compatibility

---
 .../modular_granitemoehybrid.py               | 23 +++++++++++++++----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
index 65e729cac9a4..5ded3cf48ed7 100644
--- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
@@ -25,7 +25,7 @@
 from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, logging
+from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling, logging
 from ...utils.generic import check_model_inputs
 from ..bamba.configuration_bamba import BambaConfig
 from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated, HybridMambaAttentionDynamicCache
@@ -276,10 +276,23 @@ def _update_mamba_mask(self, attention_mask, cache_position):
             1. Cached forward
             2. Attending to all inputs
         """
-        mamba_mask = attention_mask
-        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
-            mamba_mask = None
-        return mamba_mask
+        cached = cache_position[0] > 0
+        all_attend = torch.all(attention_mask == 1)
+        pred = cached | all_attend
+
+        if not is_torchdynamo_compiling:
+            # keep original None if not exporting
+            return None if bool(pred) else attention_mask
+
+        # compiling/exporting -> always return tensor
+        def true_fn(mask):
+            # return a tensor of ones instead of None
+            return torch.ones_like(mask)
+
+        def false_fn(mask):
+            return mask
+
+        return torch.cond(pred, true_fn, false_fn, (attention_mask,))
 
 
 class GraniteMoeHybridForCausalLM(GraniteMoeSharedForCausalLM):

From e29509eccfd4304076e86159d062efb285db7c53 Mon Sep 17 00:00:00 2001
From: juanigp
Date: Fri, 5 Dec 2025 09:24:18 +0100
Subject: [PATCH 2/4] fix typo

---
 .../models/granitemoehybrid/modular_granitemoehybrid.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
index 5ded3cf48ed7..e78303e50fcc 100644
--- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
@@ -280,7 +280,7 @@ def _update_mamba_mask(self, attention_mask, cache_position):
         all_attend = torch.all(attention_mask == 1)
         pred = cached | all_attend
 
-        if not is_torchdynamo_compiling:
+        if not is_torchdynamo_compiling():
             # keep original None if not exporting
             return None if bool(pred) else attention_mask
 

From e3da7e063196bc474258ec906e8a554f1183dd2e Mon Sep 17 00:00:00 2001
From: juanigp
Date: Fri, 5 Dec 2025 09:29:54 +0100
Subject: [PATCH 3/4] add _update_mamba_mask eager exit

---
 .../models/granitemoehybrid/modular_granitemoehybrid.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
index e78303e50fcc..c485c9876814 100644
--- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
@@ -276,6 +276,10 @@ def _update_mamba_mask(self, attention_mask, cache_position):
             1. Cached forward
             2. Attending to all inputs
         """
+        # eager exit if None
+        if attention_mask is None:
+            return None
+
         cached = cache_position[0] > 0
         all_attend = torch.all(attention_mask == 1)
         pred = cached | all_attend

From 92e6c83e3fd3e4628a7ef39228caa36fb8526490 Mon Sep 17 00:00:00 2001
From: juanigp
Date: Fri, 5 Dec 2025 10:43:49 +0100
Subject: [PATCH 4/4] Update modular_granitemoehybrid.py

---
 .../models/granitemoehybrid/modular_granitemoehybrid.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
index c485c9876814..1bbf1b0f0cae 100644
--- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
@@ -294,7 +294,7 @@ def true_fn(mask):
             return torch.ones_like(mask)
 
         def false_fn(mask):
-            return mask
+            return mask.clone()
 
         return torch.cond(pred, true_fn, false_fn, (attention_mask,))
 
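
Note (not part of the patches): the series converges on the behaviour sketched below. This is a minimal standalone sketch of the final _update_mamba_mask logic, assuming a PyTorch recent enough to expose torch.cond and torch.compiler.is_compiling; the function name update_mamba_mask and the stubbed is_torchdynamo_compiling helper are illustrative stand-ins, not part of the transformers code.

import torch


def is_torchdynamo_compiling():
    # Stand-in for transformers.utils.is_torchdynamo_compiling.
    return torch.compiler.is_compiling()


def update_mamba_mask(attention_mask, cache_position):
    # Eager exit: no mask was passed in.
    if attention_mask is None:
        return None

    cached = cache_position[0] > 0
    all_attend = torch.all(attention_mask == 1)
    pred = cached | all_attend

    if not is_torchdynamo_compiling():
        # Plain eager behaviour, matching the original implementation:
        # drop the mask when decoding from cache or when nothing is padded.
        return None if bool(pred) else attention_mask

    # Under compile/export both branches must return a tensor (torch.cond
    # cannot return None), so "no mask" is encoded as an all-ones mask.
    def true_fn(mask):
        return torch.ones_like(mask)

    def false_fn(mask):
        # clone() so the branch does not return an aliased input
        return mask.clone()

    return torch.cond(pred, true_fn, false_fn, (attention_mask,))


if __name__ == "__main__":
    mask = torch.tensor([[1, 1, 0, 0]])
    # Prefill with padding (cache_position starts at 0): mask is kept.
    print(update_mamba_mask(mask, torch.arange(4)))
    # Decode step (cache_position[0] > 0): mask is dropped, i.e. None in eager.
    print(update_mamba_mask(mask, torch.tensor([4])))

Under torch.export or torch.compile the same call would instead return an all-ones tensor for the "drop the mask" case, which keeps the graph free of data-dependent None returns.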