Commit 2df0c32

byebye torch 2.1 (#40317)
* Bump minimum torch version to 2.2
* Remove is_torch_greater_or_equal_than_2_2
* update versions table
* Deprecate is_torch_sdpa_available (except for backward compat), remove require_torch_sdpa
1 parent c50f140 commit 2df0c32
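
The net effect for users is a higher install floor: setup.py and the internal dependency table below now pin torch>=2.2, and SDPA is assumed to work on every supported torch build. A minimal sketch of a downstream runtime guard (hypothetical, not part of this commit; it assumes the internal `deps` dict stays importable):

    import torch
    from packaging import version
    from transformers.dependency_versions_table import deps

    # Hypothetical guard for code that depends on transformers: the library now assumes torch >= 2.2.
    if version.parse(torch.__version__) < version.parse("2.2"):
        raise RuntimeError(f"transformers requires torch>=2.2, found {torch.__version__}")

    # The same floor is recorded in the internal dependency table updated in this commit.
    assert deps["torch"] == "torch>=2.2"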

60 files changed: 17 additions & 223 deletions


setup.py

Lines changed: 1 addition & 1 deletion
@@ -190,7 +190,7 @@
     "tiktoken",
     "timm<=1.0.19,!=1.0.18",
     "tokenizers>=0.21,<0.22",
-    "torch>=2.1",
+    "torch>=2.2",
     "torchaudio",
     "torchvision",
     "pyctcdecode>=0.4.0",

src/transformers/dependency_versions_table.py

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@
     "tiktoken": "tiktoken",
     "timm": "timm<=1.0.19,!=1.0.18",
     "tokenizers": "tokenizers>=0.21,<0.22",
-    "torch": "torch>=2.1",
+    "torch": "torch>=2.2",
     "torchaudio": "torchaudio",
     "torchvision": "torchvision",
     "pyctcdecode": "pyctcdecode>=0.4.0",

src/transformers/models/albert/modeling_albert.py

Lines changed: 0 additions & 10 deletions
@@ -38,7 +38,6 @@
 from ...pytorch_utils import (
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
-    is_torch_greater_or_equal_than_2_2,
     prune_linear_layer,
 )
 from ...utils import ModelOutput, auto_docstring, logging
@@ -356,7 +355,6 @@ class AlbertSdpaAttention(AlbertAttention):
     def __init__(self, config):
         super().__init__(config)
         self.dropout_prob = config.attention_probs_dropout_prob
-        self.require_contiguous_qkv = not is_torch_greater_or_equal_than_2_2

     def forward(
         self,
@@ -392,14 +390,6 @@ def forward(
             .transpose(1, 2)
         )

-        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
-        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
-        # Reference: https://github.com/pytorch/pytorch/issues/112577
-        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
-            query_layer = query_layer.contiguous()
-            key_layer = key_layer.contiguous()
-            value_layer = value_layer.contiguous()
-
         attention_output = torch.nn.functional.scaled_dot_product_attention(
             query=query_layer,
             key=key_layer,
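
The block deleted above only worked around a torch 2.1.2 bug in the CUDA memory-efficient SDPA backend with non-contiguous q/k/v and a custom attn_mask (pytorch/pytorch#112577), fixed in torch 2.2.0. With 2.2 as the minimum, the projections can be passed to SDPA exactly as the view/transpose produces them. A self-contained sketch with dummy shapes (illustration only, not the Albert module itself):

    import torch
    import torch.nn.functional as F

    batch, heads, seq, head_dim = 2, 4, 8, 16
    # view + transpose, as in the modeling code above, yields non-contiguous q/k/v
    q = torch.randn(batch, seq, heads * head_dim).view(batch, seq, heads, head_dim).transpose(1, 2)
    k = torch.randn(batch, seq, heads * head_dim).view(batch, seq, heads, head_dim).transpose(1, 2)
    v = torch.randn(batch, seq, heads * head_dim).view(batch, seq, heads, head_dim).transpose(1, 2)
    attn_mask = torch.zeros(batch, 1, seq, seq)  # additive mask, broadcast over heads

    # On torch >= 2.2 this needs no .contiguous() calls, even on CUDA with a custom mask.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    print(out.shape)  # torch.Size([2, 4, 8, 16])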

src/transformers/models/distilbert/modeling_distilbert.py

Lines changed: 0 additions & 10 deletions
@@ -44,7 +44,6 @@
 from ...pytorch_utils import (
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
-    is_torch_greater_or_equal_than_2_2,
     prune_linear_layer,
 )
 from ...utils import (
@@ -338,7 +337,6 @@ class DistilBertSdpaAttention(MultiHeadSelfAttention):
     def __init__(self, config: PretrainedConfig):
         super().__init__(config=config)
         self.dropout_prob = config.attention_dropout
-        self.require_contiguous_qkv = not is_torch_greater_or_equal_than_2_2

     def forward(
         self,
@@ -391,14 +389,6 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
         k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
         v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

-        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
-        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
-        # Reference: https://github.com/pytorch/pytorch/issues/112577
-        if self.require_contiguous_qkv and q.device.type == "cuda" and mask is not None:
-            q = q.contiguous()
-            k = k.contiguous()
-            v = v.contiguous()
-
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             q,
             k,
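
The same simplification applies to DistilBERT. Since SDPA no longer depends on the installed torch version, requesting it at load time needs no gating either; a hedged usage sketch (the checkpoint name is only an example):

    from transformers import AutoModel

    # Any SDPA-capable checkpoint works; distilbert-base-uncased is just illustrative.
    model = AutoModel.from_pretrained("distilbert-base-uncased", attn_implementation="sdpa")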

src/transformers/pytorch_utils.py

Lines changed: 1 addition & 1 deletion
@@ -38,9 +38,9 @@
 is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
 is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
 is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
-is_torch_greater_or_equal_than_2_2 = is_torch_greater_or_equal("2.2", accept_dev=True)

 # For backwards compatibility (e.g. some remote codes on Hub using those variables).
+is_torch_greater_or_equal_than_2_2 = is_torch_greater_or_equal("2.2", accept_dev=True)
 is_torch_greater_or_equal_than_2_1 = is_torch_greater_or_equal("2.1", accept_dev=True)
 is_torch_greater_or_equal_than_2_0 = is_torch_greater_or_equal("2.0", accept_dev=True)
 is_torch_greater_or_equal_than_1_13 = is_torch_greater_or_equal("1.13", accept_dev=True)
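
For code that imported the flag removed from internal use, the forward-compatible pattern is to call the helper directly; the module-level constant is kept only so existing Hub remote code keeps importing. A sketch, assuming is_torch_greater_or_equal remains exposed via transformers.utils:

    from transformers.utils import is_torch_greater_or_equal

    # Preferred replacement for the module-level is_torch_greater_or_equal_than_2_2 flag.
    if is_torch_greater_or_equal("2.2", accept_dev=True):
        print("torch is at or above the new minimum")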

src/transformers/testing_utils.py

Lines changed: 0 additions & 10 deletions
@@ -159,7 +159,6 @@
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_optimi_available,
-    is_torch_sdpa_available,
     is_torch_tensorrt_fx_available,
     is_torch_tf32_available,
     is_torch_xla_available,
@@ -624,15 +623,6 @@ def require_flash_attn_3(test_case):
     return unittest.skipUnless(is_flash_attn_3_available(), "test requires Flash Attention 3")(test_case)


-def require_torch_sdpa(test_case):
-    """
-    Decorator marking a test that requires PyTorch's SDPA.
-
-    These tests are skipped when requirements are not met (torch version).
-    """
-    return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)
-
-
 def require_read_token(test_case):
     """
     A decorator that loads the HF token for tests that require to load gated models.
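
Tests that previously used the removed decorator can simply drop it, since SDPA ships with every supported torch. A test that genuinely needs a newer torch can use the existing require_torch_greater_or_equal decorator instead; a sketch of the pattern, assuming its decorator-factory form (the test itself is hypothetical):

    import unittest

    from transformers.testing_utils import require_torch, require_torch_greater_or_equal

    class ExampleSdpaTest(unittest.TestCase):  # hypothetical test, for illustration only
        @require_torch
        @require_torch_greater_or_equal("2.3")  # only for features newer than the 2.2 floor
        def test_sdpa_related_behaviour(self):
            self.assertTrue(True)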

src/transformers/utils/import_utils.py

Lines changed: 2 additions & 9 deletions
@@ -451,17 +451,10 @@ def get_torch_major_and_minor_version() -> str:


 def is_torch_sdpa_available():
+    # Mostly retained for backward compatibility in remote code, since sdpa works correctly on all torch versions >= 2.2
     if not is_torch_available() or _torch_version == "N/A":
         return False
-
-    # NOTE: MLU is OK with non-contiguous inputs.
-    if is_torch_mlu_available():
-        return True
-    # NOTE: NPU can use SDPA in Transformers with torch>=2.1.0.
-    if is_torch_npu_available():
-        return True
-    # NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577
-    return version.parse(_torch_version) >= version.parse("2.1.1")
+    return True


 def is_torch_flex_attn_available():
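
After this change is_torch_sdpa_available is only a backward-compatibility shim: it returns False when torch is missing and True otherwise, because every supported torch (>= 2.2) ships SDPA, so the MLU/NPU and version special cases are gone. Remote code that still calls it keeps working; a sketch, assuming the function stays exported from transformers.utils:

    from transformers.utils import is_torch_sdpa_available

    # On any install satisfying the new torch>=2.2 requirement this is simply True;
    # it only reports False when torch itself is absent.
    print(is_torch_sdpa_available())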

tests/generation/test_utils.py

Lines changed: 0 additions & 2 deletions
@@ -51,7 +51,6 @@
     require_torch_gpu,
     require_torch_greater_or_equal,
     require_torch_multi_accelerator,
-    require_torch_sdpa,
     set_config_for_less_flaky_test,
     set_model_for_less_flaky_test,
     set_model_tester_for_less_flaky_test,
@@ -2366,7 +2365,6 @@ def _test_attention_implementation(self, attn_implementation):
         self.assertTrue(has_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3))

     @pytest.mark.generate
-    @require_torch_sdpa
     @slow
     def test_eager_matches_sdpa_generate(self):
         """Tests that generate has equivalent outputs with SDPA and eager attention implementations."""

tests/models/aimv2/test_modeling_aimv2.py

Lines changed: 0 additions & 2 deletions
@@ -28,7 +28,6 @@
     require_flash_attn,
     require_torch,
     require_torch_gpu,
-    require_torch_sdpa,
     require_vision,
     slow,
     torch_device,
@@ -563,7 +562,6 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
         )

     @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
-    @require_torch_sdpa
     def test_eager_matches_sdpa_inference(
         self,
         name,

tests/models/blip_2/test_modeling_blip_2.py

Lines changed: 0 additions & 3 deletions
@@ -31,7 +31,6 @@
     require_torch_fp16,
     require_torch_gpu,
     require_torch_multi_accelerator,
-    require_torch_sdpa,
     require_vision,
     slow,
     torch_device,
@@ -508,7 +507,6 @@ def test_retain_grad_hidden_states_attentions(self):
     def test_model_get_set_embeddings(self):
         pass

-    @require_torch_sdpa
     def test_sdpa_can_dispatch_composite_models(self):
         """
         Tests if composite models dispatch correctly on SDPA/eager when requested so when loading the model.
@@ -945,7 +943,6 @@ def test_model_get_set_embeddings(self):
     def test_cpu_offload(self):
         pass

-    @require_torch_sdpa
     def test_sdpa_can_dispatch_composite_models(self):
         """
         Tests if composite models dispatch correctly on SDPA/eager when requested so when loading the model.
