You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
@@ -13,103 +13,145 @@ rendered properly in your Markdown viewer.
13
13
14
14
-->
15
15
16
-
# Attention Interface
16
+
# Attention backends
17
17
18
-
This page describes how to use the `AttentionInterface` in order to register custom attention functions to use with
19
-
supported models.
18
+
All attention implementations perform the same computation. Every token is compared to every other token. The difference is *how* the computation is performed. Basic attention scales poorly because it materializes the full attention matrix in memory, creating bottlenecks that slow down inference. Optimized implementations rearrange the math to reduce memory traffic for faster, more affordable inference.
20
19
21
-
## Customizing attention function
20
+
The [`AttentionInterface`] provides optimized attention implementations. It decouples the attention implementation from the model implementation to simplify experimentation with different functions. Add new backends easily with this consistent interface.
22
21
23
-
Most recent models can now switch from one attention function used in the Attention layer to the other, thanks to a simple mapping.
24
-
By default, we provide the implementation for [`sdpa`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
25
-
[`flash_attention_2`](https://github.com/Dao-AILab/flash-attention) and [`flex_attention`](https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention)
26
-
as well as `eager`, which is a simple matrix multiplication without any optimization on top.
27
-
This is the setting you can usually choose when instantiating a model:
22
+
| attention backend | description |
23
+
|---|---|
24
+
|`"flash_attention_3"`| improves FlashAttention-2 by also overlapping operations and fusing forward and backward passes more tightly |
25
+
|`"flash_attention_2"`| tiles computations into smaller blocks and uses fast on-chip memory |
26
+
|`"flex_attention"`| framework for specifying custom attention patterns (sparse, block-local, sliding window) without writing low-level kernels by hand |
27
+
|`"sdpa"`| built-in PyTorch implementation of [scaled dot product attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)|
28
+
| <code>"paged|flash_attention_2"</code> | Paged version of FlashAttention-2 |
29
+
| <code>"paged|sdpa"</code> | Paged version of SDPA |
30
+
| <code>"paged|eager"</code> | Paged version of eager |
28
31
29
-
```python
30
-
from transformers import AutoModelForCausalLM
32
+
## Set an attention backend
31
33
32
-
model_id ="meta-llama/Llama-3.2-1B"
34
+
Use the `attn_implementation` argument in [`~PreTrainedModel.from_pretrained`] to instantiate a model with a specific attention function.
35
+
36
+
```py
37
+
import torch
38
+
from transformers import AutoModelForCausalLM
33
39
34
-
# Here, using flash attention as an example
35
-
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")
But what if you wanted to create your own attention function? Or simply play around with existing ones, adding
39
-
a few statements here and there? You can now do so with the `AttentionInterface`! Here is an example:
45
+
Switch between attention backends at runtime without reloading the model using [`~PreTrainedModel.set_attn_implementation`].
40
46
41
-
```python
42
-
from transformers import AutoModelForCausalLM, AttentionInterface
43
-
from transformers.integrations.sdpa_attention import sdpa_attention_forward
44
-
import torch
47
+
```py
48
+
model.set_attn_implementation("sdpa")
49
+
```
45
50
46
-
model_id ="meta-llama/Llama-3.2-1B"
51
+
### Kernels
47
52
48
-
defmy_new_sdpa(*args, **kwargs):
49
-
print("I just entered the attention computation")
50
-
return sdpa_attention_forward(*args, **kwargs)
53
+
Download and load compiled compute kernels directly from the [Hub](https://huggingface.co/models?other=kernels) at runtime with the [Kernels](https://huggingface.co/docs/kernels/index) library. This avoids packaging issues from mismatched PyTorch or CUDA versions.
You will see it prints "I just entered the attention computation" as many times as there are layers in the model (with this example, 16 times).
66
+
### SDPA context manager
60
67
61
-
## Dynamically switching attention function
68
+
PyTorch's scaled dot product attention (SDPA) selects the fastest attention function for CUDA backends automatically. It defaults to the PyTorch C++ implementation for other backends.
62
69
63
-
You could dynamically change the model's attention function as well:
70
+
Force SDPA to use a specific implementation with the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager.
64
71
65
-
```python
66
-
# Back to use original sdpa implementation
67
-
model.set_attn_implementation("sdpa")
72
+
```py
73
+
import torch
74
+
from torch.nn.attention import SDPBackend, sdpa_kernel
and it will stop printing the statements, as it now uses the `sdpa` attention.
73
-
This allows to quickly change an attention function, without needing to reload the model!
85
+
## Backbone-specific attention
74
86
75
-
## Different attention per backbonein multimodal models
87
+
Multimodal models use different backbones for each modality. Optimize performance by assigning specific attention functions to each backbone. Some vision backbones perform better in fp32, for example, which FlashAttention does not support.
76
88
77
-
For multimodal models different attention functions may work better for each backbone module. For example, some vision backbones perform better in fp32, but are incompatible with FlashAttention. To continue using FlashAttention while keeping the vision encoder in fp32, create a dict and map each config to an attention implementation as shown below.
89
+
Map vision backbones to different attention functions with a dict while the text backbone continues to use FlashAttention. Keys in the attention implementation must match sub-config names.
78
90
79
-
```python
91
+
```py
80
92
from transformers import AutoModelForImageTextToText
## What about new args needed in my custom attention function?
128
+
## Create a new attention function
129
+
130
+
Customize or create new attention functions by adding them to the attention registry with [`AttentionInterface.register`]. Models use these functions through the `attn_implementation` argument.
104
131
105
-
But indeed, what if the new function requires a new arg to be properly used? It's no issue! Models supporting the
106
-
`AttentionInterface` propagate kwargs all the way to the Attention layers, and to the used attention function. That way,
107
-
you can simply pass the arg (as a kwargs, i.e. you need to qualify the name of the arg) in the model's forward, and it will be correctly used in the attention. However, custom attention functions have some limitations. In particular, it must follow the signature and return format of other attention functions, i.e.
132
+
This example customizes the attention function to print a statement for each layer.
108
133
109
134
```python
135
+
import torch
110
136
from transformers import AutoModelForCausalLM, AttentionInterface
111
137
from transformers.integrations.sdpa_attention import sdpa_attention_forward
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="my_new_sdpa")
146
+
model(torch.ones(1, 5, dtype=int))
147
+
```
148
+
149
+
You can also add new arguments to the attention function. Models supporting [`AttentionInterface`] propagate kwargs to attention layers and the attention function. Pass arguments as kwargs in the model's forward function. Custom attention functions must follow this signature and return format.
150
+
151
+
```python
112
152
import torch
153
+
from transformers import AutoModelForCausalLM, AttentionInterface
154
+
from transformers.integrations.sdpa_attention import sdpa_attention_forward
If in doubt about what args/kwargs a given model sends to the attention function, simply check that model's modeling code on [GitHub](https://github.com/huggingface/transformers/tree/main/src/transformers/models)!
175
+
Check a model's [modeling code](https://github.com/huggingface/transformers/tree/main/src/transformers/models) to confirm what arguments and kwargs it sends to the attention function.
135
176
136
-
## Accessing current available implementations
177
+
### AttentionMaskInterface
137
178
138
-
Most of the time, you will simply need to `register` a new function. If, however, you need to access an existing one,
139
-
and/or perform a few checks, the preferred way is to use the global`ALL_ATTENTION_FUNCTIONS`. It behaves the same way you
Having a new attention function may mean that you need a new format of attention mask to decide what key and value tokens
161
-
the query tokens should attend to. This is now possible with the `AttentionMaskInterface`! It works in the same way as
162
-
the `AttentionInterface`:
179
+
Configure which key and value tokens queries attend to with [`AttentionMaskInterface`]. Some attention functions require this configuration. Customize the attention mask function and add it to the registry with [`AttentionMaskInterface.register`].
163
180
164
181
```python
182
+
import torch
165
183
from transformers import AttentionMaskInterface
166
184
from transformers.masking_utils import sdpa_mask
167
-
import torch
168
185
169
186
def my_new_sdpa_mask(*args, **kwargs):
170
187
print("I just entered the attention mask computation")
The reason you have to register it is because we need to automatically correct your mask format based on the attention implementation (for example, flex attention uses a BlockMask format, while sdpa uses a 4D tensor).
177
-
By default, if you do not register an attention mask function along with your attention function, mask creation will be skipped
178
-
and `attention_mask=None` will be passed along to the Attention layers.
193
+
Registered attention masks automatically correct the mask format for the attention implementation. For example, FlexAttention uses a [BlockMask](https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html?utm_source=chatgpt.com#torch.nn.attention.flex_attention.BlockMask) format, while SDPA uses a 4D tensor. Without a registered attention mask function, mask creation is skipped and `attention_mask=None` passes to the model's attention layers.
179
194
180
-
The default signature of the attention mask functions is the following:
195
+
This is the default signature for an attention mask function.
181
196
182
197
```python
183
198
defcustom_attention_mask(
@@ -191,6 +206,6 @@ def custom_attention_mask(
191
206
) -> Optional[torch.Tensor]:
192
207
```
193
208
194
-
It mostly works thanks to the `mask_function`, which is a `Callable`in the form of [torch's mask_mod functions](https://pytorch.org/blog/flexattention/), taking 4 indices as input and returning a boolean to indicate if this position should take part in the attention computation.
209
+
The `mask_function` argument is a `Callable`that mimics PyTorch's [mask_mod](https://pytorch.org/blog/flexattention/) functions. It takes 4 indices as input and returns a boolean. This boolean indicates if the position contributes to the attention computation.
195
210
196
-
If you cannot use the `mask_function` to create your mask for some reason, you can try to work around it by doing something similar to our [torch export workaround](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py).
211
+
Use this [workaround](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py) for torch export if `mask_function` fails to create a mask.
0 commit comments