diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 555caaf9..72a61102 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -186,17 +186,19 @@ def _tpu_flash_attention(
     kv_max_block_size = key.shape[1]
   else:
     kv_max_block_size = q_max_block_size
-  if flash_block_sizes:
+  # ensure that for cross attention we override the block sizes.
+  if flash_block_sizes and key.shape[1] == query.shape[1]:
     block_sizes = flash_block_sizes
   else:
+    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=min(q_max_block_size, query.shape[2]),
+        block_q=block_size_q,
         block_kv_compute=min(kv_max_block_size, key.shape[2]),
         block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=min(q_max_block_size, query.shape[2]),
+        block_q_dkv=block_size_q,
         block_kv_dkv=min(kv_max_block_size, key.shape[2]),
         block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq,
+        block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
         block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
         use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
     )
@@ -215,7 +217,6 @@ def _tpu_flash_attention(
       check_rep=False,
   )
   def wrap_flash_attention(query, key, value):
-
     uses_fused_kernel = block_sizes.use_fused_bwd_kernel
     block_q_sizes = (
         block_sizes.block_q,
@@ -1042,7 +1043,6 @@ def setup(self):
     )
 
   def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
-
     qkv_proj = self.qkv(hidden_states)
     B, L = hidden_states.shape[:2]
     H, D, K = self.heads, qkv_proj.shape[-1] // (self.heads * 3), 3
@@ -1054,7 +1054,6 @@ def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=Non
     key_proj = self.key_norm(key_proj)
 
     if encoder_hidden_states is not None:
-
       encoder_qkv_proj = self.encoder_qkv(encoder_hidden_states)
       B, L = encoder_hidden_states.shape[:2]
       H, D, K = self.heads, encoder_qkv_proj.shape[-1] // (self.heads * 3), 3
@@ -1148,7 +1147,6 @@ class FlaxAttention(nn.Module):
   quant: Quant = None
 
   def setup(self):
-
     if self.attention_kernel == "flash" and self.mesh is None:
       raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
     inner_dim = self.dim_head * self.heads
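
The sketch below illustrates the block-size selection this diff introduces: user-supplied flash_block_sizes are used unchanged only when the query and key sequences match (self-attention); for cross attention, block_q is kept from the configuration while the KV block dimensions are recomputed against the shorter key sequence. The BlockSizes dataclass and the select_block_sizes function here are stand-ins written for illustration, not the maxdiffusion or splash-attention API; the scalar sequence-length arguments are an assumption that simplifies the shape handling in the real code.

```python
# Minimal sketch of the block-size selection logic from the diff above.
# BlockSizes is a local stand-in for splash_attention_kernel.BlockSizes;
# select_block_sizes is a hypothetical helper, not a maxdiffusion function.
from dataclasses import dataclass
from typing import Optional


@dataclass
class BlockSizes:
  block_q: int
  block_kv: int
  block_kv_compute: int
  block_q_dkv: Optional[int] = None
  block_kv_dkv: Optional[int] = None
  block_kv_dkv_compute: Optional[int] = None
  block_q_dq: Optional[int] = None
  block_kv_dq: Optional[int] = None
  use_fused_bwd_kernel: bool = False


def select_block_sizes(
    q_seq_len: int,
    kv_seq_len: int,
    q_max_block_size: int,
    kv_max_block_size: int,
    flash_block_sizes: Optional[BlockSizes] = None,
    attention_kernel: str = "flash",
) -> BlockSizes:
  # Self-attention with a user override: use the configured block sizes as-is.
  if flash_block_sizes and q_seq_len == kv_seq_len:
    return flash_block_sizes
  # Cross attention (or no override): keep the configured block_q, but derive
  # the KV-side blocks from the actual sequence lengths, mirroring the diff.
  block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
  is_tokamax = attention_kernel == "tokamax_flash"
  return BlockSizes(
      block_q=block_size_q,
      block_kv_compute=min(kv_max_block_size, kv_seq_len),
      block_kv=min(kv_max_block_size, kv_seq_len),
      block_q_dkv=block_size_q,
      block_kv_dkv=min(kv_max_block_size, kv_seq_len),
      block_kv_dkv_compute=min(kv_max_block_size, q_seq_len),
      block_q_dq=None if is_tokamax else block_size_q,
      block_kv_dq=None if is_tokamax else min(kv_max_block_size, q_seq_len),
      use_fused_bwd_kernel=is_tokamax,
  )
```

Under these assumptions, a cross-attention call whose key sequence is much shorter than its query sequence keeps block_q at the configured value while block_kv, block_kv_compute, and block_kv_dkv shrink to the key length, rather than inheriting self-attention block sizes that exceed the KV sequence.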