@@ -25,7 +25,7 @@
 from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, logging
+from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling, logging
 from ...utils.generic import check_model_inputs
 from ..bamba.configuration_bamba import BambaConfig
 from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated, HybridMambaAttentionDynamicCache
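For context, is_torchdynamo_compiling() reports whether TorchDynamo (torch.compile / torch.export) is currently tracing. Below is a minimal sketch, not taken from this PR, of the guard pattern the new import enables; pick_branch and its logic are illustrative only.

import torch
from transformers.utils import is_torchdynamo_compiling

def pick_branch(x: torch.Tensor) -> torch.Tensor:
    if not is_torchdynamo_compiling():
        # eager mode: a data-dependent Python `if` is fine here
        return x if bool((x > 0).all()) else torch.zeros_like(x)
    # while tracing: express the same choice with tensor ops to avoid a graph break
    return torch.where((x > 0).all(), x, torch.zeros_like(x))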
@@ -276,10 +276,27 @@ def _update_mamba_mask(self, attention_mask, cache_position):
             1. Cached forward
             2. Attending to all inputs
         """
-        mamba_mask = attention_mask
-        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
-            mamba_mask = None
-        return mamba_mask
+        # eager exit if None
+        if attention_mask is None:
+            return None
+
+        cached = cache_position[0] > 0
+        all_attend = torch.all(attention_mask == 1)
+        pred = cached | all_attend
+
+        if not is_torchdynamo_compiling():
+            # keep original None if not exporting
+            return None if bool(pred) else attention_mask
+
+        # compiling/exporting -> always return tensor
+        def true_fn(mask):
+            # return a tensor of ones instead of None
+            return torch.ones_like(mask)
+
+        def false_fn(mask):
+            return mask.clone()
+
+        return torch.cond(pred, true_fn, false_fn, (attention_mask,))
 
 
 class GraniteMoeHybridForCausalLM(GraniteMoeSharedForCausalLM):
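A hypothetical standalone sketch (not part of this PR) of the torch.cond pattern added above: both branches must return tensors of matching shape and dtype, which is why the compile/export path returns torch.ones_like(mask) as an "attend to everything" sentinel rather than None. The name update_mask_sketch is illustrative; recent PyTorch releases also execute torch.cond eagerly, otherwise wrap the call in torch.compile.

import torch

def update_mask_sketch(attention_mask: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
    # True once decoding from cache has started or when nothing is actually masked
    pred = (cache_position[0] > 0) | torch.all(attention_mask == 1)
    return torch.cond(
        pred,
        lambda mask: torch.ones_like(mask),  # sentinel: no zeroing of Mamba states needed
        lambda mask: mask.clone(),           # keep the padding mask as-is
        (attention_mask,),
    )

mask = torch.tensor([[1, 1, 0, 0]])
print(update_mask_sketch(mask, torch.tensor([0])))  # prefill with padding -> mask kept
print(update_mask_sketch(mask, torch.tensor([5])))  # cached decode step -> all-ones sentinel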