
Commit 76db8da

Merge remote-tracking branch 'upstream/main' into batch_spec_dec

2 parents: 7a6001c + ff13eb6


60 files changed: +3319, -1451 lines

docs/source/en/_toctree.yml

Lines changed: 22 additions & 23 deletions
@@ -23,8 +23,6 @@
 title: Legacy model contribution
 - local: auto_docstring
 title: Documenting a model
-- local: attention_interface
-title: Customizing attention function
 title: Models
 - sections:
 - local: fast_tokenizers
@@ -61,11 +59,29 @@
 - local: llm_tutorial
 title: Text generation
 - local: generation_strategies
-title: Generation strategies
+title: Decoding methods
 - local: generation_features
 title: Generation features
 - local: tasks/prompting
 title: Prompt engineering
+- local: perplexity
+title: Perplexity of fixed-length models
+title: Generate API
+- sections:
+- local: attention_interface
+title: Attention backends
+- local: continuous_batching
+title: Continuous batching
+- local: kernel_doc/overview
+title: Kernels in transformers
+- local: perf_torch_compile
+title: torch.compile
+- local: perf_infer_gpu_one
+title: GPU
+- local: perf_infer_gpu_multi
+title: Distributed inference
+- local: perf_infer_cpu
+title: CPU
 - local: llm_optims
 title: Optimizing inference
 - local: cache_explanation
@@ -74,9 +90,7 @@
 title: KV cache strategies
 - local: llm_tutorial_optimization
 title: Getting the most out of LLMs
-- local: perplexity
-title: Perplexity of fixed-length models
-title: LLMs
+title: Optimization
 - sections:
 - local: conversations
 title: Chat basics
@@ -101,24 +115,12 @@
 - local: open_webui
 title: Open WebUI
 title: Serving
-- sections:
-- local: perf_torch_compile
-title: torch.compile
-- local: perf_infer_gpu_one
-title: GPU
-- local: perf_infer_gpu_multi
-title: Distributed inference
-- local: perf_infer_cpu
-title: CPU
-title: Optimization
 - local: agents
 title: Agents
 - local: tools
 title: Tools
 - local: transformers_as_backend
 title: Transformers as modeling backend
-- local: continuous_batching
-title: Continuous Batching
 title: Inference
 - isExpanded: false
 sections:
@@ -218,11 +220,6 @@
 - local: quantization/contribute
 title: Contribute
 title: Quantization
-- isExpanded: false
-sections:
-- local: kernel_doc/overview
-title: Kernels in transformers
-title: Kernels
 - isExpanded: false
 sections:
 - local: serialization
@@ -904,6 +901,8 @@
 title: Hubert
 - local: model_doc/kyutai_speech_to_text
 title: Kyutai Speech-To-Text
+- local: model_doc/lasr
+title: LASR
 - local: model_doc/mimi
 title: Mimi
 - local: model_doc/mms

docs/source/en/attention_interface.md

Lines changed: 105 additions & 90 deletions
@@ -13,103 +13,145 @@ rendered properly in your Markdown viewer.
 
 -->
 
-# Attention Interface
+# Attention backends
 
-This page describes how to use the `AttentionInterface` in order to register custom attention functions to use with
-supported models.
+All attention implementations perform the same computation. Every token is compared to every other token. The difference is *how* the computation is performed. Basic attention scales poorly because it materializes the full attention matrix in memory, creating bottlenecks that slow down inference. Optimized implementations rearrange the math to reduce memory traffic for faster, more affordable inference.
 
-## Customizing attention function
+The [`AttentionInterface`] provides optimized attention implementations. It decouples the attention implementation from the model implementation to simplify experimentation with different functions. Add new backends easily with this consistent interface.
 
-Most recent models can now switch from one attention function used in the Attention layer to the other, thanks to a simple mapping.
-By default, we provide the implementation for [`sdpa`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
-[`flash_attention_2`](https://github.com/Dao-AILab/flash-attention) and [`flex_attention`](https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention)
-as well as `eager`, which is a simple matrix multiplication without any optimization on top.
-This is the setting you can usually choose when instantiating a model:
+| attention backend | description |
+|---|---|
+| `"flash_attention_3"` | improves FlashAttention-2 by also overlapping operations and fusing forward and backward passes more tightly |
+| `"flash_attention_2"` | tiles computations into smaller blocks and uses fast on-chip memory |
+| `"flex_attention"` | framework for specifying custom attention patterns (sparse, block-local, sliding window) without writing low-level kernels by hand |
+| `"sdpa"` | built-in PyTorch implementation of [scaled dot product attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) |
+| <code>"paged&#124;flash_attention_2"</code> | Paged version of FlashAttention-2 |
+| <code>"paged&#124;sdpa"</code> | Paged version of SDPA |
+| <code>"paged&#124;eager"</code> | Paged version of eager |
 
-```python
-from transformers import AutoModelForCausalLM
+## Set an attention backend
 
-model_id = "meta-llama/Llama-3.2-1B"
+Use the `attn_implementation` argument in [`~PreTrainedModel.from_pretrained`] to instantiate a model with a specific attention function.
+
+```py
+import torch
+from transformers import AutoModelForCausalLM
 
-# Here, using flash attention as an example
-model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")
+model = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_2"
+)
 ```
 
-But what if you wanted to create your own attention function? Or simply play around with existing ones, adding
-a few statements here and there? You can now do so with the `AttentionInterface`! Here is an example:
+Switch between attention backends at runtime without reloading the model using [`~PreTrainedModel.set_attn_implementation`].
 
-```python
-from transformers import AutoModelForCausalLM, AttentionInterface
-from transformers.integrations.sdpa_attention import sdpa_attention_forward
-import torch
+```py
+model.set_attn_implementation("sdpa")
+```
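As a quick check alongside the snippets above (not part of the commit), the resolved backend can be read back from the model config; this is a minimal sketch that assumes the private `_attn_implementation` attribute, an implementation detail the documentation above does not cover.

```py
import torch
from transformers import AutoModelForCausalLM

# Load with one backend, switch at runtime, and read back the resolved setting.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", attn_implementation="sdpa"
)
print(model.config._attn_implementation)  # expected: "sdpa"

model.set_attn_implementation("eager")
print(model.config._attn_implementation)  # expected: "eager"
```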
 
-model_id = "meta-llama/Llama-3.2-1B"
+### Kernels
 
-def my_new_sdpa(*args, **kwargs):
-    print("I just entered the attention computation")
-    return sdpa_attention_forward(*args, **kwargs)
+Download and load compiled compute kernels directly from the [Hub](https://huggingface.co/models?other=kernels) at runtime with the [Kernels](https://huggingface.co/docs/kernels/index) library. This avoids packaging issues from mismatched PyTorch or CUDA versions.
 
-AttentionInterface.register("my_new_sdpa", my_new_sdpa)
+Kernels automatically register to [`AttentionInterface`] upon detection. You don't need to install the FlashAttention package explicitly.
 
-model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="my_new_sdpa")
-# Try running the forward with the new attention function
-model(torch.ones(1, 5, dtype=int))
+```py
+import torch
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-3.2-1B", attn_implementation="kernels-community/flash-attn2"
+)
 ```
 
-You will see it prints "I just entered the attention computation" as many times as there are layers in the model (with this example, 16 times).
+### SDPA context manager
 
-## Dynamically switching attention function
+PyTorch's scaled dot product attention (SDPA) selects the fastest attention function for CUDA backends automatically. It defaults to the PyTorch C++ implementation for other backends.
 
-You could dynamically change the model's attention function as well:
+Force SDPA to use a specific implementation with the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager.
 
-```python
-# Back to use original sdpa implementation
-model.set_attn_implementation("sdpa")
+```py
+import torch
+from torch.nn.attention import SDPBackend, sdpa_kernel
+from transformers import AutoModelForCausalLM
 
-model(torch.ones(1, 5, dtype=int))
+model = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-3.2-1B", attn_implementation="sdpa"
+)
+
+with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+    outputs = model.generate(**inputs)
 ```
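The `generate` call in the block above assumes an `inputs` dict already exists; a minimal, self-contained variant follows, where the prompt text and tokenizer usage are illustrative additions rather than part of the commit.

```py
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", attn_implementation="sdpa"
)

# Build the inputs the snippet above relies on, then pin SDPA to its
# FlashAttention kernel for this generation call.
inputs = tokenizer("Attention backends are", return_tensors="pt")
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```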
 
-and it will stop printing the statements, as it now uses the `sdpa` attention.
-This allows to quickly change an attention function, without needing to reload the model!
+## Backbone-specific attention
 
-## Different attention per backbone in multimodal models
+Multimodal models use different backbones for each modality. Optimize performance by assigning specific attention functions to each backbone. Some vision backbones perform better in fp32, for example, which FlashAttention does not support.
 
-For multimodal models different attention functions may work better for each backbone module. For example, some vision backbones perform better in fp32, but are incompatible with FlashAttention. To continue using FlashAttention while keeping the vision encoder in fp32, create a dict and map each config to an attention implementation as shown below.
+Map vision backbones to different attention functions with a dict while the text backbone continues to use FlashAttention. Keys in the attention implementation must match sub-config names.
 
-```python
+```py
 from transformers import AutoModelForImageTextToText
 
-model_id = "facebook/chameleon-7b"
-
 attention_implementation_per_backbone = {"vision_config": "sdpa", "text_config": "flash_attention_2"}
-model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation=attention_implementation_per_backbone)
 
-# NOTE: keys in the attention implementation have to be the same as the sub-config names
 for key in attention_implementation_per_backbone:
     assert key in model.config.sub_configs, f"Invalid key in `attention_implementation`"
 
-# You can omit certain backbones - the default attention function (SDPA) will be used
-# This is equivalent to the previous example
-model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation={"text_config": "flash_attention_2"})
+model = AutoModelForImageTextToText.from_pretrained(
+    "facebook/chameleon-7b", attn_implementation=attention_implementation_per_backbone
+)
+```
+
+Omit certain backbones from the dict to use the default attention function (SDPA).
+
+```py
+model = AutoModelForImageTextToText.from_pretrained(
+    "facebook/chameleon-7b", attn_implementation={"text_config": "flash_attention_2"}
+)
+```
+
+Set the same attention function for all backbones with a single string.
 
+```py
+model = AutoModelForImageTextToText.from_pretrained(
+    "facebook/chameleon-7b", attn_implementation="eager"
+)
+```
 
-# Set the same attention implementation for all backbones with single string, same as in non-multimodal models
-model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager")
+Set the attention function globally with an empty key.
 
-# Alternatively use a dict with an empty key for global configuration
-model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation={"": "eager"})
+```py
+model = AutoModelForImageTextToText.from_pretrained(
+    "facebook/chameleon-7b", attn_implementation={"": "eager"}
+)
 ```
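One caveat in the first per-backbone block above: the key check references `model` before it is created. Below is a reordered sketch of the same check (the key name appended to the error message is an editorial addition).

```py
from transformers import AutoModelForImageTextToText

attention_implementation_per_backbone = {"vision_config": "sdpa", "text_config": "flash_attention_2"}

# Load first, then confirm every key names a real sub-config of the loaded model.
model = AutoModelForImageTextToText.from_pretrained(
    "facebook/chameleon-7b", attn_implementation=attention_implementation_per_backbone
)
for key in attention_implementation_per_backbone:
    assert key in model.config.sub_configs, f"Invalid key in `attention_implementation`: {key}"
```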
 
-## What about new args needed in my custom attention function?
+## Create a new attention function
+
+Customize or create new attention functions by adding them to the attention registry with [`AttentionInterface.register`]. Models use these functions through the `attn_implementation` argument.
 
-But indeed, what if the new function requires a new arg to be properly used? It's no issue! Models supporting the
-`AttentionInterface` propagate kwargs all the way to the Attention layers, and to the used attention function. That way,
-you can simply pass the arg (as a kwargs, i.e. you need to qualify the name of the arg) in the model's forward, and it will be correctly used in the attention. However, custom attention functions have some limitations. In particular, it must follow the signature and return format of other attention functions, i.e.
+This example customizes the attention function to print a statement for each layer.
 
 ```python
+import torch
 from transformers import AutoModelForCausalLM, AttentionInterface
 from transformers.integrations.sdpa_attention import sdpa_attention_forward
+
+def my_new_sdpa(*args, **kwargs):
+    print("I just entered the attention computation")
+    return sdpa_attention_forward(*args, **kwargs)
+
+AttentionInterface.register("my_new_sdpa", my_new_sdpa)
+
+model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="my_new_sdpa")
+model(torch.ones(1, 5, dtype=int))
+```
+
+You can also add new arguments to the attention function. Models supporting [`AttentionInterface`] propagate kwargs to attention layers and the attention function. Pass arguments as kwargs in the model's forward function. Custom attention functions must follow this signature and return format.
+
+```python
 import torch
+from transformers import AutoModelForCausalLM, AttentionInterface
+from transformers.integrations.sdpa_attention import sdpa_attention_forward
 
 def custom_attention(
     module: torch.nn.Module, # required arg
@@ -127,44 +169,19 @@ def custom_attention(
 AttentionInterface.register("custom", custom_attention)
 
 model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="custom")
-# Forward pass with the new kwargs
 model(torch.ones(1, 5, dtype=int), a_new_kwargs=..., another_new_kwargs=...)
 ```
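To make the kwargs flow concrete, here is a small sketch (not from the commit) that wraps `sdpa_attention_forward` and consumes one made-up keyword argument, `debug_tag`, passed through the model's forward; it assumes the model propagates unrecognized kwargs down to the attention function as described above.

```python
import torch
from transformers import AutoModelForCausalLM, AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward

def sdpa_with_tag(*args, debug_tag: str = "", **kwargs):
    # `debug_tag` is a hypothetical extra kwarg; everything else is forwarded unchanged.
    if debug_tag:
        print(f"attention call: {debug_tag}")
    return sdpa_attention_forward(*args, **kwargs)

AttentionInterface.register("sdpa_with_tag", sdpa_with_tag)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", attn_implementation="sdpa_with_tag"
)
model(torch.ones(1, 5, dtype=int), debug_tag="first forward")
```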
 
-If in doubt about what args/kwargs a given model sends to the attention function, simply check that model's modeling code on [GitHub](https://github.com/huggingface/transformers/tree/main/src/transformers/models)!
+Check a model's [modeling code](https://github.com/huggingface/transformers/tree/main/src/transformers/models) to confirm what arguments and kwargs it sends to the attention function.
 
-## Accessing current available implementations
+### AttentionMaskInterface
 
-Most of the time, you will simply need to `register` a new function. If, however, you need to access an existing one,
-and/or perform a few checks, the preferred way is to use the global `ALL_ATTENTION_FUNCTIONS`. It behaves the same way you
-would expect from a usual Python dictionary:
-
-```python
->>> from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
-
->>> list(ALL_ATTENTION_FUNCTIONS.keys())
->>> ['flash_attention_2', 'flex_attention', 'sdpa']
-
->>> ALL_ATTENTION_FUNCTIONS["sdpa"]
->>> <function transformers.integrations.sdpa_attention.sdpa_attention_forward>
-
->>> ALL_ATTENTION_FUNCTIONS.get("sdpa", None)
->>> <function transformers.integrations.sdpa_attention.sdpa_attention_forward>
-
-# You can also globally `register` a new function directly on it
->>> ALL_ATTENTION_FUNCTIONS.register("new_func", new_func)
-```
-
-## Attention Mask Interface
-
-Having a new attention function may mean that you need a new format of attention mask to decide what key and value tokens
-the query tokens should attend to. This is now possible with the `AttentionMaskInterface`! It works in the same way as
-the `AttentionInterface`:
+Configure which key and value tokens queries attend to with [`AttentionMaskInterface`]. Some attention functions require this configuration. Customize the attention mask function and add it to the registry with [`AttentionMaskInterface.register`].
 
 ```python
+import torch
 from transformers import AttentionMaskInterface
 from transformers.masking_utils import sdpa_mask
-import torch
 
 def my_new_sdpa_mask(*args, **kwargs):
     print("I just entered the attention mask computation")
@@ -173,11 +190,9 @@ def my_new_sdpa_mask(*args, **kwargs):
 AttentionMaskInterface.register("my_new_sdpa_mask", my_new_sdpa_mask)
 ```
 
-The reason you have to register it is because we need to automatically correct your mask format based on the attention implementation (for example, flex attention uses a BlockMask format, while sdpa uses a 4D tensor).
-By default, if you do not register an attention mask function along with your attention function, mask creation will be skipped
-and `attention_mask=None` will be passed along to the Attention layers.
+Registered attention masks automatically correct the mask format for the attention implementation. For example, FlexAttention uses a [BlockMask](https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html?utm_source=chatgpt.com#torch.nn.attention.flex_attention.BlockMask) format, while SDPA uses a 4D tensor. Without a registered attention mask function, mask creation is skipped and `attention_mask=None` passes to the model's attention layers.
 
-The default signature of the attention mask functions is the following:
+This is the default signature for an attention mask function.
 
 ```python
 def custom_attention_mask(
@@ -191,6 +206,6 @@ def custom_attention_mask(
 ) -> Optional[torch.Tensor]:
 ```
 
-It mostly works thanks to the `mask_function`, which is a `Callable` in the form of [torch's mask_mod functions](https://pytorch.org/blog/flexattention/), taking 4 indices as input and returning a boolean to indicate if this position should take part in the attention computation.
+The `mask_function` argument is a `Callable` that mimics PyTorch's [mask_mod](https://pytorch.org/blog/flexattention/) functions. It takes 4 indices as input and returns a boolean. This boolean indicates if the position contributes to the attention computation.
 
-If you cannot use the `mask_function` to create your mask for some reason, you can try to work around it by doing something similar to our [torch export workaround](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py).
+Use this [workaround](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py) for torch export if `mask_function` fails to create a mask.
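For intuition about the `mask_function` shape described above, a small sketch of such a callable written against the FlexAttention `mask_mod` convention (the sliding-window size is an arbitrary illustrative value, not from the commit):

```python
# A mask_mod-style callable: the four indices identify one (query, key/value)
# pair; returning True keeps that position in the attention computation.
def sliding_window_causal(batch_idx, head_idx, q_idx, kv_idx):
    causal = q_idx >= kv_idx             # never attend to future tokens
    in_window = (q_idx - kv_idx) <= 128  # keep only the most recent 128 tokens
    return causal & in_window
```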

0 commit comments
