
diff --git a/docs/source/en/perf_infer_gpu_multi.md b/docs/source/en/perf_infer_gpu_multi.md
index 893dd28d7b45..7f835509ce75 100644
--- a/docs/source/en/perf_infer_gpu_multi.md
+++ b/docs/source/en/perf_infer_gpu_multi.md
@@ -306,3 +306,7 @@ The most important part of DTensor is the `placement` attribute because it tells
```
- `Partial()` - Indicates a tensor is pending a reduction operation (not typically relevant for usage in Transformers).
+
+## Resources
+
+Read the [Tensor Parallelism (TP) in Transformers: 5 Minutes to Understand](https://huggingface.co/blog/qgallouedec/tp) blog post for a quick overview of tensor parallelism and to learn how column- and row-parallel setups differ.
\ No newline at end of file
diff --git a/docs/source/en/philosophy.md b/docs/source/en/philosophy.md
index e0a7d082156d..f7941b618da6 100644
--- a/docs/source/en/philosophy.md
+++ b/docs/source/en/philosophy.md
@@ -21,7 +21,7 @@ Transformers is a PyTorch-first library. It provides models that are faithful to
A longer, in-depth article with examples, visualizations and timelines is available [here](https://huggingface.co/spaces/transformers-community/Transformers-tenets) as our canonical reference.
> [!NOTE]
-> Our philosophy evolves through practice. What follows are out current, stable principles.
+> Our philosophy evolves through practice. What follows are our current, stable principles.
## Who this library is for
diff --git a/docs/source/en/quantization/contribute.md b/docs/source/en/quantization/contribute.md
index 1a990f1f0f6f..4481bc7b6225 100644
--- a/docs/source/en/quantization/contribute.md
+++ b/docs/source/en/quantization/contribute.md
@@ -46,15 +46,15 @@ Some quantization methods may require "pre-quantizing" the model through data ca
## Create new HFQuantizer class
+0. The best starting point is to have a look at an existing quantization method such as fine-grained FP8. You will have to update or create three files in total: the [config file](https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py), the [integration file](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/finegrained_fp8.py) and the [quantizer file](https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_finegrained_fp8.py).
+
1. Create a new quantization config class inside [src/transformers/utils/quantization_config.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/utils/quantization_config.py). Add the new quantization config to the [_import_structure](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py#L1088) inside Transformers' [src/transformers/__init__.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py) file.
2. Create a new file inside [src/transformers/quantizers/](https://github.com/huggingface/transformers/tree/abbffc4525566a48a9733639797c812301218b83/src/transformers/quantizers) named `quantizer_your_method.py`, and make it inherit from [`~quantizers.HfQuantizer]. Make sure to add the new quantizer and quantization config in the quantization auto-mapping in [src/transformers/quantizers/auto.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/quantizers/auto.py).
-3. Define the following class attributes and property methods for your quantization method.
+3. Define the following class attributes and property methods for your quantization method:
- `requires_calibration`: Whether the quantization method requires a data calibration process. If set to `True`, you can only support inference (with quantized weights) and not inference and quantization.
- - `required_packages`: A list of strings of the required packages to use the quantized weights. You might need to define some new utility methods such as `is_auto_awq_available` in [transformers/src/utils/import_utils.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/utils/import_utils.py).
- - `requires_parameters_quantization`: Only required if your quantization method requires extra attention to the underlying [nn.Parameter](https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html) object. For example, bitsandbytes uses [`~bitsandbytes.nn.Params4bit`] and [`~bitsandbytes.nn.Int8Params`], which requires some extra attention when quantizing the model. Most of the recent quantization method packs int2 and int4 weights inside [torch.uint8](https://pytorch.org/docs/stable/tensors.html) weights, so this flag should not be really required (set to `False` by default).
- `is_serializable`: A property method to determine whether the method is serializable or not.
- `is_trainable`: A property method to determine whether you can fine-tune models on top of the quantization method (with or without PEFT approaches).
@@ -62,10 +62,14 @@ Some quantization methods may require "pre-quantizing" the model through data ca
5. Write the `_process_model_before_weight_loading` method. In Transformers, the quantized models are initialized first on the `"meta"` device before loading the weights. This means the `_process_model_before_weight_loading` method takes care of manipulating the model skeleton to replace some modules ([nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)) with the target modules (quantization modules).
- You can define module replacement logic or any other utility method by creating a new file in [transformers/src/integrations/](https://github.com/huggingface/transformers/tree/abbffc4525566a48a9733639797c812301218b83/src/transformers/integrations) and exposing the relevant methods in that folder's `__init__.py` file. The best starting point would be to have a look at another quantization method such as [quantizer_awq.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/quantizers/quantizer_awq.py).
+   You can define module replacement logic or any other utility method by creating a new file in [transformers/src/integrations/](https://github.com/huggingface/transformers/tree/abbffc4525566a48a9733639797c812301218b83/src/transformers/integrations) and exposing the relevant methods in that folder's `__init__.py` file.
+
+6. Add the `get_quantize_ops` method to the quantizer class if the quantization method supports quantizing on the fly. In Transformers, we materialize each tensor and apply a sequence of different operations on it; the quantization operation happens at the end. You need to create a `XXXQuantize` class, a subclass of `ConversionOps`, and add a `convert` method that quantizes the weights and returns a dictionary of quantized params (see the sketch after this list).
+
+7. Add the `get_weight_conversions` method to the quantizer class if the quantization method supports loading pre-quantized weights. In Transformers, we can collect multiple tensors and apply operations on them. This is particularly useful when tensors in the checkpoint need to be regrouped to re-create the quantized tensors.
-6. Write the `_process_model_after_weight_loading` method. This method enables implementing additional features that require manipulating the model after loading the weights.
+8. Write the `_process_model_after_weight_loading` method if needed. This method enables implementing additional features that require manipulating the model after loading the weights.
-7. Document everything! Make sure your quantization method is documented by adding a new file under `docs/source/en/quantization`.
+9. Document everything! Make sure your quantization method is documented by adding a new file under `docs/source/en/quantization`.
-8. You should add tests by adding the package in our nightly Dockerfile inside `docker/transformers-quantization-latest-gpu` and then adding a new test file in `tests/quantization/xxx`. Feel free to check out existing quantization methods to see how it is implemented.
+10. You should add tests by adding the package to our nightly Dockerfile inside `docker/transformers-quantization-latest-gpu` and then adding a new test file in `tests/quantization/xxx`. Feel free to check out existing quantization methods to see how they are implemented.
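+
+As a minimal, illustrative sketch of step 6 (names such as `MyMethodQuantize`, `MyMethodHfQuantizer` and `my_quantize_fn` are placeholders, and the exact `convert`/`get_quantize_ops` signatures may differ - check an existing quantizer such as the fine-grained FP8 one for the real interface):
+
+```python
+import torch
+
+from transformers.core_model_loading import ConversionOps  # assumed location of ConversionOps
+from transformers.quantizers import HfQuantizer
+
+
+def my_quantize_fn(weight: torch.Tensor):
+    # Placeholder: real methods call into their quantization kernel here (simple int8 absmax for illustration)
+    scale = weight.abs().amax() / 127.0
+    return (weight / scale).round().clamp(-128, 127).to(torch.int8), scale
+
+
+class MyMethodQuantize(ConversionOps):
+    """Hypothetical on-the-fly quantization op, applied to each tensor after it is materialized."""
+
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(self, input_dict, source_patterns=None, model=None, full_layer_name=None, missing_keys=None, **kwargs):
+        # `input_dict` maps a source pattern to the tensor(s) collected for this parameter
+        value = list(input_dict.values())[0]
+        weight = value[0] if isinstance(value, list) else value
+        quantized_weight, scale = my_quantize_fn(weight)
+        # Return a dictionary of quantized params to set on the model
+        return {full_layer_name: quantized_weight, f"{full_layer_name}_scale": scale}
+
+
+class MyMethodHfQuantizer(HfQuantizer):
+    ...
+
+    def get_quantize_ops(self):
+        # Assumed hook shape: return the operation used to quantize weights on the fly
+        return MyMethodQuantize(self)
+```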
diff --git a/docs/source/ja/main_classes/callback.md b/docs/source/ja/main_classes/callback.md
index f1cb357f7eb8..388bd7ff0f69 100644
--- a/docs/source/ja/main_classes/callback.md
+++ b/docs/source/ja/main_classes/callback.md
@@ -37,7 +37,6 @@ rendered properly in your Markdown viewer.
- [`~integrations.WandbCallback`] [wandb](https://www.wandb.com/) がインストールされている場合。
- [`~integrations.CometCallback`] [comet_ml](https://www.comet.com/site/) がインストールされている場合。
- [mlflow](https://www.mlflow.org/) がインストールされている場合は [`~integrations.MLflowCallback`]。
-- [`~integrations.NeptuneCallback`] [neptune](https://neptune.ai/) がインストールされている場合。
- [`~integrations.AzureMLCallback`] [azureml-sdk](https://pypi.org/project/azureml-sdk/) の場合
インストールされています。
- [`~integrations.CodeCarbonCallback`] [codecarbon](https://pypi.org/project/codecarbon/) の場合
@@ -82,8 +81,6 @@ rendered properly in your Markdown viewer.
[[autodoc]] integrations.CodeCarbonCallback
-[[autodoc]] integrations.NeptuneCallback
-
[[autodoc]] integrations.ClearMLCallback
[[autodoc]] integrations.DagsHubCallback
diff --git a/docs/source/ko/main_classes/callback.md b/docs/source/ko/main_classes/callback.md
index c8d122a8ef92..de3dd54ef7f7 100644
--- a/docs/source/ko/main_classes/callback.md
+++ b/docs/source/ko/main_classes/callback.md
@@ -36,7 +36,6 @@ rendered properly in your Markdown viewer.
사용됩니다.
- [`~integrations.CometCallback`]는 [comet_ml](https://www.comet.com/site/)이 설치되어 있으면 사용됩니다.
- [`~integrations.MLflowCallback`]는 [mlflow](https://www.mlflow.org/)가 설치되어 있으면 사용됩니다.
-- [`~integrations.NeptuneCallback`]는 [neptune](https://neptune.ai/)이 설치되어 있으면 사용됩니다.
- [`~integrations.AzureMLCallback`]는 [azureml-sdk](https://pypi.org/project/azureml-sdk/)가 설치되어
있으면 사용됩니다.
- [`~integrations.CodeCarbonCallback`]는 [codecarbon](https://pypi.org/project/codecarbon/)이 설치되어
@@ -82,8 +81,6 @@ rendered properly in your Markdown viewer.
[[autodoc]] integrations.CodeCarbonCallback
-[[autodoc]] integrations.NeptuneCallback
-
[[autodoc]] integrations.ClearMLCallback
[[autodoc]] integrations.DagsHubCallback
diff --git a/docs/source/zh/main_classes/callback.md b/docs/source/zh/main_classes/callback.md
index 36c1898f018b..b80d0da386f6 100644
--- a/docs/source/zh/main_classes/callback.md
+++ b/docs/source/zh/main_classes/callback.md
@@ -30,7 +30,6 @@ Callbacks是“只读”的代码片段,除了它们返回的[TrainerControl]
- [`~integrations.WandbCallback`],如果安装了[wandb](https://www.wandb.com/)。
- [`~integrations.CometCallback`],如果安装了[comet_ml](https://www.comet.com/site/)。
- [`~integrations.MLflowCallback`],如果安装了[mlflow](https://www.mlflow.org/)。
-- [`~integrations.NeptuneCallback`],如果安装了[neptune](https://neptune.ai/)。
- [`~integrations.AzureMLCallback`],如果安装了[azureml-sdk](https://pypi.org/project/azureml-sdk/)。
- [`~integrations.CodeCarbonCallback`],如果安装了[codecarbon](https://pypi.org/project/codecarbon/)。
- [`~integrations.ClearMLCallback`],如果安装了[clearml](https://github.com/allegroai/clearml)。
@@ -71,8 +70,6 @@ Callbacks是“只读”的代码片段,除了它们返回的[TrainerControl]
[[autodoc]] integrations.CodeCarbonCallback
-[[autodoc]] integrations.NeptuneCallback
-
[[autodoc]] integrations.ClearMLCallback
[[autodoc]] integrations.DagsHubCallback
diff --git a/examples/pytorch/README.md b/examples/pytorch/README.md
index c9f288ac36b6..77066900315b 100644
--- a/examples/pytorch/README.md
+++ b/examples/pytorch/README.md
@@ -199,7 +199,6 @@ You can easily log and monitor your runs code. The following are currently suppo
* [TensorBoard](https://www.tensorflow.org/tensorboard)
* [Weights & Biases](https://docs.wandb.ai/integrations/huggingface)
* [Comet ML](https://www.comet.com/docs/v2/integrations/ml-frameworks/transformers/)
-* [Neptune](https://docs.neptune.ai/integrations-and-supported-tools/model-training/hugging-face)
* [ClearML](https://clear.ml/docs/latest/docs/getting_started/ds/ds_first_steps)
* [DVCLive](https://dvc.org/doc/dvclive/ml-frameworks/huggingface)
@@ -256,91 +255,6 @@ or if in a Conda environment:
conda install -c comet_ml -c anaconda -c conda-forge comet_ml
```
-### Neptune
-
-First, install the Neptune client library. You can do it with either `pip` or `conda`:
-
-`pip`:
-
-```bash
-pip install neptune
-```
-
-`conda`:
-
-```bash
-conda install -c conda-forge neptune
-```
-
-Next, in your model training script, import `NeptuneCallback`:
-
-```python
-from transformers.integrations import NeptuneCallback
-```
-
-To enable Neptune logging, in your `TrainingArguments`, set the `report_to` argument to `"neptune"`:
-
-```python
-training_args = TrainingArguments(
- "quick-training-distilbert-mrpc",
- eval_strategy="steps",
- eval_steps=20,
- report_to="neptune",
-)
-
-trainer = Trainer(
- model,
- training_args,
- ...
-)
-```
-
-**Note:** This method requires saving your Neptune credentials as environment variables (see the bottom of the section).
-
-Alternatively, for more logging options, create a Neptune callback:
-
-```python
-neptune_callback = NeptuneCallback()
-```
-
-To add more detail to the tracked run, you can supply optional arguments to `NeptuneCallback`.
-
-Some examples:
-
-```python
-neptune_callback = NeptuneCallback(
- name = "DistilBERT",
- description = "DistilBERT fine-tuned on GLUE/MRPC",
- tags = ["args-callback", "fine-tune", "MRPC"], # tags help you manage runs in Neptune
- base_namespace="callback", # the default is "finetuning"
- log_checkpoints = "best", # other options are "last", "same", and None
- capture_hardware_metrics = False, # additional keyword arguments for a Neptune run
-)
-```
-
-Pass the callback to the Trainer:
-
-```python
-training_args = TrainingArguments(..., report_to=None)
-trainer = Trainer(
- model,
- training_args,
- ...
- callbacks=[neptune_callback],
-)
-```
-
-Now, when you start the training with `trainer.train()`, your metadata will be logged in Neptune.
-
-**Note:** Although you can pass your **Neptune API token** and **project name** as arguments when creating the callback, the recommended way is to save them as environment variables:
-
-| Environment variable | Value |
-| :------------------- | :--------------------------------------------------- |
-| `NEPTUNE_API_TOKEN` | Your Neptune API token. To find and copy it, click your Neptune avatar and select **Get your API token**. |
-| `NEPTUNE_PROJECT` | The full name of your Neptune project (`workspace-name/project-name`). To find and copy it, head to **project settings** → **Properties**. |
-
-For detailed instructions and examples, see the [Neptune docs](https://docs.neptune.ai/integrations/transformers/).
-
### ClearML
To use ClearML, install the clearml package with:
diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py
index ba6d34cd12cb..cac0b3ee3d8d 100755
--- a/examples/pytorch/summarization/run_summarization.py
+++ b/examples/pytorch/summarization/run_summarization.py
@@ -45,6 +45,7 @@
import numpy as np
from datasets import load_dataset
from filelock import FileLock
+from huggingface_hub import is_offline_mode
import transformers
from transformers import (
@@ -61,7 +62,7 @@
Seq2SeqTrainingArguments,
set_seed,
)
-from transformers.utils import check_min_version, is_offline_mode
+from transformers.utils import check_min_version
from transformers.utils.versions import require_version
diff --git a/examples/pytorch/summarization/run_summarization_no_trainer.py b/examples/pytorch/summarization/run_summarization_no_trainer.py
index fdd1a098e2df..bb402e933c5f 100644
--- a/examples/pytorch/summarization/run_summarization_no_trainer.py
+++ b/examples/pytorch/summarization/run_summarization_no_trainer.py
@@ -51,7 +51,7 @@
from accelerate.utils import set_seed
from datasets import load_dataset
from filelock import FileLock
-from huggingface_hub import HfApi
+from huggingface_hub import HfApi, is_offline_mode
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
@@ -66,7 +66,7 @@
SchedulerType,
get_scheduler,
)
-from transformers.utils import check_min_version, is_offline_mode
+from transformers.utils import check_min_version
from transformers.utils.versions import require_version
diff --git a/pyproject.toml b/pyproject.toml
index 54ec1618e384..dc8a22c98c0e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -66,7 +66,8 @@ markers = [
"flash_attn_3_test: marks tests related to flash attention 3 (deselect with '-m \"not flash_attn_3_test\"')",
"flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')",
"bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
- "generate: marks tests that use the GenerationTesterMixin"
+ "generate: marks tests that use the GenerationTesterMixin",
+ "is_training_test: marks tests that use the TrainingTesterMixin (deselect with '-m \"not is_training_test\"')",
]
log_cli = 1
log_cli_level = "WARNING"
diff --git a/setup.py b/setup.py
index 7eaad9289e68..6702c47aec8a 100644
--- a/setup.py
+++ b/setup.py
@@ -99,7 +99,6 @@
"blobfile",
"codecarbon>=2.8.1",
"cookiecutter==1.7.3",
- "dataclasses",
"datasets>=2.15.0", # We need either this pin or pyarrow<21.0.0
"deepspeed>=0.9.3",
"diffusers",
@@ -113,7 +112,7 @@
"GitPython<3.1.19",
"hf-doc-builder>=0.3.0",
"hf_xet",
- "huggingface-hub>=1.0.0,<2.0",
+ "huggingface-hub>=1.2.1,<2.0",
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
"jinja2>=3.1.0",
diff --git a/splitted_tests.txt b/splitted_tests.txt
deleted file mode 100644
index ae7dc9d6e8d1..000000000000
--- a/splitted_tests.txt
+++ /dev/null
@@ -1 +0,0 @@
-tests/models/afmoe/test_modeling_afmoe.py
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 6dad9d7c05cb..c3d364bd4736 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -383,6 +383,8 @@
"BayesianDetectorConfig",
"BayesianDetectorModel",
"ClassifierFreeGuidanceLogitsProcessor",
+ "ContinuousBatchingManager",
+ "ContinuousMixin",
"EncoderNoRepeatNGramLogitsProcessor",
"EncoderRepetitionPenaltyLogitsProcessor",
"EosTokenCriteria",
@@ -536,6 +538,8 @@
from .generation import BayesianDetectorModel as BayesianDetectorModel
from .generation import ClassifierFreeGuidanceLogitsProcessor as ClassifierFreeGuidanceLogitsProcessor
from .generation import CompileConfig as CompileConfig
+ from .generation import ContinuousBatchingManager as ContinuousBatchingManager
+ from .generation import ContinuousMixin as ContinuousMixin
from .generation import EncoderNoRepeatNGramLogitsProcessor as EncoderNoRepeatNGramLogitsProcessor
from .generation import EncoderRepetitionPenaltyLogitsProcessor as EncoderRepetitionPenaltyLogitsProcessor
from .generation import EosTokenCriteria as EosTokenCriteria
diff --git a/src/transformers/activations.py b/src/transformers/activations.py
index 08a30b3dd88c..1312bede777e 100644
--- a/src/transformers/activations.py
+++ b/src/transformers/activations.py
@@ -345,8 +345,8 @@ def forward(self, input: Tensor) -> Tensor:
def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
- else:
- raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
+
+ raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
# For backwards compatibility with: from activations import gelu_python
diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py
index 5968bd08d406..24eab78c14fc 100644
--- a/src/transformers/conversion_mapping.py
+++ b/src/transformers/conversion_mapping.py
@@ -228,7 +228,7 @@ def get_model_conversion_mapping(
"""
weight_conversions = []
- # Load models with key mapping
+ # Load models with explicit, user-provided key mapping
if key_mapping is not None:
weight_conversions = [WeightRenaming(source_patterns=k, target_patterns=v) for k, v in key_mapping.items()]
elif any(
diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py
index 6e26e2360a74..01bb9c3770b4 100644
--- a/src/transformers/core_model_loading.py
+++ b/src/transformers/core_model_loading.py
@@ -20,7 +20,7 @@
import re
from abc import abstractmethod
from collections import defaultdict
-from collections.abc import MutableMapping, MutableSet
+from collections.abc import Callable, MutableMapping, MutableSet
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import contextmanager
from copy import deepcopy
@@ -31,7 +31,7 @@
from .integrations.accelerate import offload_weight
from .integrations.tensor_parallel import ALL_PARALLEL_STYLES
-from .utils import is_torch_greater_or_equal, logging
+from .utils import is_env_variable_true, is_torch_greater_or_equal, logging
_torch_distributed_available = torch.distributed.is_available()
@@ -327,10 +327,6 @@ def add_tensor(self, target_key: str, source_key: str, source_pattern: str, futu
self.collected_tensors[source_pattern].append(future)
self.layer_targets[target_key].add(source_key)
- def reset(self) -> None:
- """Clean-up the collected tensors to make sure we don't keep references to past tensors in memory."""
- self.collected_tensors = defaultdict(list)
-
def rename_source_key(self, source_key: str) -> tuple[str, str | None]:
"""
Return a tuple (renamed_key, source_pattern_producing_the_match).
@@ -375,6 +371,32 @@ def reverse_transform(self) -> WeightTransform:
return reverse_transform
+ def materialize_tensors(self) -> dict[str, list[torch.Tensor]]:
+ """
+        Materialize all the tensors that were saved in `self.collected_tensors`. This function removes them from the
+        internal attribute to avoid keeping them in memory during the different `self.convert` operations, and returns
+        a new dictionary (otherwise we would use more memory than needed during loading).
+
+        We basically have 3 cases here:
+        - async loading (default): the tensors are Future instances that we need to wait for
+        - sync loading: the tensors are Callables that we need to call to actually load them from disk
+        - saving: the tensors are already torch.Tensor instances (the existing model weights)
+ """
+ collected_tensors = {}
+ for key in set(self.collected_tensors.keys()):
+ # Remove from internal attribute
+ tensors = self.collected_tensors.pop(key)
+ # Async loading
+ if isinstance(tensors[0], Future):
+ tensors = [future.result() for future in tensors]
+ # Sync loading
+ elif callable(tensors[0]):
+ tensors = [func() for func in tensors]
+ # Add them to the new dictionary
+ collected_tensors[key] = tensors
+
+ return collected_tensors
+
@dataclass(slots=True)
class WeightRenaming(WeightTransform):
@@ -389,19 +411,17 @@ def convert(
missing_keys: Optional[MutableSet[str]] = None,
misc: Optional[MutableMapping[str, str]] = None,
):
- # Collect the tensor if using threading
- for pattern, futures in self.collected_tensors.items():
- self.collected_tensors[pattern] = (
- futures if isinstance(futures[0], torch.Tensor) else [future.result() for future in futures]
- )
+ # Collect the tensors here - we use a new dictionary to avoid keeping them in memory in the internal
+ # attribute during the whole process
+ collected_tensors = self.materialize_tensors()
# Perform renaming op (for a simple WeightRenaming, `self.source_patterns` and `self.target_patterns` can
# only be of length 1, and are actually the full key names - we also have only 1 single related tensor)
target_key = self.target_patterns[0]
- collected_tensors = {target_key: self.collected_tensors[self.source_patterns[0]]}
+ collected_tensors = {target_key: collected_tensors[self.source_patterns[0]]}
if hf_quantizer is not None and self.quantization_operation is not None:
- with log_to_misc(layer_name, misc, (self.collected_tensors, layer_name), self.quantization_operation):
+ with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), self.quantization_operation):
collected_tensors = self.quantization_operation.convert(
collected_tensors,
source_patterns=self.source_patterns,
@@ -437,15 +457,12 @@ def convert(
missing_keys: Optional[MutableSet[str]] = None,
misc: Optional[MutableMapping[str, str]] = None,
):
- # Collect all tensors if using threading
- for pattern, futures in self.collected_tensors.items():
- self.collected_tensors[pattern] = (
- futures if isinstance(futures[0], torch.Tensor) else [future.result() for future in futures]
- )
+ # Collect the tensors here - we use a new dictionary to avoid keeping them in memory in the internal
+ # attribute during the whole process
+ collected_tensors = self.materialize_tensors()
- collected_tensors = self.collected_tensors
for op in self.operations:
- with log_to_misc(layer_name, misc, (collected_tensors, layer_name), op):
+ with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), op):
collected_tensors = op.convert(
collected_tensors,
source_patterns=self.source_patterns,
@@ -472,7 +489,7 @@ def convert(
pass
if hf_quantizer is not None and self.quantization_operation is not None:
- with log_to_misc(layer_name, misc, (collected_tensors, layer_name), self.quantization_operation):
+ with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), self.quantization_operation):
collected_tensors = self.quantization_operation.convert(
collected_tensors,
source_patterns=self.source_patterns,
@@ -491,25 +508,46 @@ def convert(
GLOBAL_WORKERS = min(4, os.cpu_count() or 4)
-def _materialize_copy(tensor, device=None, dtype=None):
+def _materialize_copy(tensor: torch.Tensor, device=None, dtype=None) -> torch.Tensor:
+ # This slicing is what actually loads the tensor from the safetensors slice object
tensor = tensor[...]
if dtype is not None or device is not None:
tensor = tensor.to(device=device, dtype=dtype)
return tensor
-def spawn_materialize(thread_pool, tensor, device=None, dtype=None) -> Future:
+def spawn_materialize(
+ thread_pool: ThreadPoolExecutor | None, tensor: torch.Tensor, device=None, dtype=None
+) -> Future | Callable:
+ """Materialize a tensor from file asynchronously if `thread_pool` is provided, or return a Callable that will
+ load the tensor synchronously when called."""
+
def _job():
return _materialize_copy(tensor, device, dtype)
- return thread_pool.submit(_job)
+ if thread_pool is not None:
+ return thread_pool.submit(_job)
+ else:
+ # Return the Callable here, not the Tensor itself, so we actually delay loading to avoid saturating cpu
+ # memory during Conversion
+ return _job
-def spawn_tp_materialize(thread_pool, tensor, sharding_method, tensor_idx, dtype=None) -> Future:
+def spawn_tp_materialize(
+ thread_pool: ThreadPoolExecutor | None, tensor: torch.Tensor, sharding_method, tensor_idx, dtype=None
+) -> Future | Callable:
+ """Materialize and shard a tensor (according to the TP-plan) from file asynchronously if `thread_pool` is provided, or
+ return a Callable that will load the tensor synchronously when called."""
+
def _job():
return sharding_method.shard_tensor(tensor, param_casting_dtype=dtype, tensor_idx=tensor_idx)[0]
- return thread_pool.submit(_job)
+ if thread_pool is not None:
+ return thread_pool.submit(_job)
+ else:
+ # Return the Callable here, not the Tensor itself, so we actually delay loading to avoid saturating cpu
+ # memory during Conversion
+ return _job
def dot_natural_key(s: str):
@@ -545,10 +583,10 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) ->
op_name = _format_op_name(op)
if isinstance(extras, tuple) and len(extras) == 2:
- values, target_keys = extras
+ length, target_keys = extras
descriptor = f"{op_name} " if op_name else ""
misc[first_target_key] = (
- f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values)}"
+ f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {length}"
)
elif isinstance(extras, str):
suffix = f" via {op_name}" if op_name else ""
@@ -783,13 +821,17 @@ def convert_and_load_state_dict_in_model(
misc = {}
mismatch_keys = set()
unexpected_keys = set()
- # Global thread_pool
- thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
+
+    # We use threading by default, unless explicitly deactivated via an env variable. If we have to offload to disk,
+    # we are under memory constraints and need tight control over memory, so we load sequentially in that case as well
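+    # (e.g. setting HF_DEACTIVATE_ASYNC_LOAD=1 in the environment forces the sequential loading path)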
+ if is_env_variable_true("HF_DEACTIVATE_ASYNC_LOAD") or "disk" in device_map.values():
+ thread_pool = None
+ else:
+ thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
-
- param_name_to_load: dict[str, Union[WeightRenaming | WeightConverter]] = {}
+ param_name_to_load: dict[str, WeightRenaming | WeightConverter] = {}
    # build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
# and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
@@ -841,8 +883,8 @@ def convert_and_load_state_dict_in_model(
elif empty_param is not None and empty_param.dtype != _dtype:
_dtype = empty_param.dtype # usually correct when initializing
- # 4. Handle TP sharding or device_map placement -> scheduled materialization
- future = None
+ # 4. Handle TP sharding or device_map placement
+ future_or_tensor = None
if device_mesh:
if matched_tp_pattern := tp_plan_alt.search(renamed_key):
matched_tp_pattern = tp_plan_by_group_name[matched_tp_pattern.lastgroup]
@@ -852,7 +894,7 @@ def convert_and_load_state_dict_in_model(
device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone()
)
shard_index = len(mapping.collected_tensors.get(original_key, []))
- future = spawn_tp_materialize(
+ future_or_tensor = spawn_tp_materialize(
thread_pool,
tensor,
mapping.distributed_operation,
@@ -860,14 +902,14 @@ def convert_and_load_state_dict_in_model(
_dtype,
)
- if future is None:
+ if future_or_tensor is None:
device_match = device_map_regex.match(renamed_key)
param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
# If disk, we need to materialize on cpu first
param_device = "cpu" if param_device == "disk" else param_device
- future = spawn_materialize(thread_pool, tensor, param_device, _dtype)
+ future_or_tensor = spawn_materialize(thread_pool, tensor, param_device, _dtype)
- mapping.add_tensor(renamed_key, original_key, source_pattern, future)
+ mapping.add_tensor(renamed_key, original_key, source_pattern, future_or_tensor)
elif source_pattern is not None: # add all target keys as unexpected
mapping = pattern_to_converter[source_pattern]
for k in mapping.target_patterns:
@@ -875,51 +917,58 @@ def convert_and_load_state_dict_in_model(
else:
unexpected_keys.add(renamed_key)
- total_entries = len(param_name_to_load)
- with logging.tqdm(total=total_entries, desc="Loading weights") as pbar:
- for first_param_name, mapping in param_name_to_load.items():
- pbar.update(1)
- pbar.set_postfix({"Materializing param": first_param_name})
- pbar.refresh()
- try:
- realized_value, misc = mapping.convert(
- first_param_name,
- model=model,
- config=model.config,
- hf_quantizer=hf_quantizer,
- missing_keys=missing_keys,
- misc=misc,
- )
- for target_name, param in realized_value.items():
- param = param[0] if isinstance(param, list) else param
- device_match = device_map_regex.match(target_name)
- param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
- # Offloading support
- if param_device == "disk":
- disk_offload_index = offload_and_maybe_resave_param(
- target_name, param, missing_keys, disk_offload_folder, disk_offload_index, mapping
- )
- else:
- set_param_for_module(
- model,
- target_name,
- param,
- mismatch_keys,
- missing_keys,
- misc,
- unexpected_keys,
- mapping.distributed_operation,
- hf_quantizer,
- )
-
- # Cleanup the tensors
- mapping.reset()
- except SkipLayer:
- continue
+ try:
+ total_entries = len(param_name_to_load)
+ with logging.tqdm(total=total_entries, desc="Loading weights") as pbar:
+ for first_param_name, mapping in param_name_to_load.items():
+ pbar.update(1)
+ pbar.set_postfix({"Materializing param": first_param_name})
+ pbar.refresh()
+ try:
+ realized_value, misc = mapping.convert(
+ first_param_name,
+ model=model,
+ config=model.config,
+ hf_quantizer=hf_quantizer,
+ missing_keys=missing_keys,
+ misc=misc,
+ )
+ for target_name, param in realized_value.items():
+ param = param[0] if isinstance(param, list) else param
+ device_match = device_map_regex.match(target_name)
+ param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
+ # Offloading support
+ if param_device == "disk":
+ disk_offload_index = offload_and_maybe_resave_param(
+ target_name, param, missing_keys, disk_offload_folder, disk_offload_index, mapping
+ )
+ else:
+ set_param_for_module(
+ model,
+ target_name,
+ param,
+ mismatch_keys,
+ missing_keys,
+ misc,
+ unexpected_keys,
+ mapping.distributed_operation,
+ hf_quantizer,
+ )
+
+ # Cleanup all the tensors that were gathered before next iteration
+ del realized_value
+
+ except SkipLayer:
+ continue
+
+    # Close the pool, regardless of whether the code was interrupted or finished successfully
+ finally:
+ if thread_pool is not None:
+            # `cancel_futures=True` in case the program was interrupted, to avoid wasting time on exit
+ thread_pool.shutdown(wait=False, cancel_futures=True)
# Keep the current weight conversion mapping for later saving (in case it was coming directly from the user)
model._weight_conversions = weight_mapping
- thread_pool.shutdown(wait=False)
return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, misc
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index b9d579b25fbf..00033733d75f 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -9,7 +9,6 @@
"blobfile": "blobfile",
"codecarbon": "codecarbon>=2.8.1",
"cookiecutter": "cookiecutter==1.7.3",
- "dataclasses": "dataclasses",
"datasets": "datasets>=2.15.0",
"deepspeed": "deepspeed>=0.9.3",
"diffusers": "diffusers",
@@ -23,7 +22,7 @@
"GitPython": "GitPython<3.1.19",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"hf_xet": "hf_xet",
- "huggingface-hub": "huggingface-hub>=1.0.0,<2.0",
+ "huggingface-hub": "huggingface-hub>=1.2.1,<2.0",
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",
"jinja2": "jinja2>=3.1.0",
diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py
index b70879120f73..d797831a26d1 100644
--- a/src/transformers/dynamic_module_utils.py
+++ b/src/transformers/dynamic_module_utils.py
@@ -30,7 +30,7 @@
from types import ModuleType
from typing import Any, Optional, Union
-from huggingface_hub import try_to_load_from_cache
+from huggingface_hub import is_offline_mode, try_to_load_from_cache
from packaging import version
from .utils import (
@@ -38,7 +38,6 @@
TRANSFORMERS_DYNAMIC_MODULE_NAME,
cached_file,
extract_commit_hash,
- is_offline_mode,
logging,
)
from .utils.import_utils import VersionComparison, split_package_version
diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py
index 0bdf7921cb71..ed7a0978b4e1 100644
--- a/src/transformers/feature_extraction_utils.py
+++ b/src/transformers/feature_extraction_utils.py
@@ -22,7 +22,7 @@
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
import numpy as np
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, is_offline_mode
from .dynamic_module_utils import custom_object_save
from .utils import (
@@ -32,7 +32,6 @@
TensorType,
copy_func,
is_numpy_array,
- is_offline_mode,
is_torch_available,
is_torch_device,
is_torch_dtype,
diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py
index 423cb699b804..0b1dae926255 100644
--- a/src/transformers/file_utils.py
+++ b/src/transformers/file_utils.py
@@ -68,7 +68,6 @@
is_in_notebook,
is_ipex_available,
is_librosa_available,
- is_offline_mode,
is_onnx_available,
is_pandas_available,
is_phonemizer_available,
diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py
index 92ef3184e773..a4728fe693c8 100644
--- a/src/transformers/generation/__init__.py
+++ b/src/transformers/generation/__init__.py
@@ -86,7 +86,11 @@
"StopStringCriteria",
]
_import_structure["continuous_batching"] = [
+ "ContinuousBatchingManager",
"ContinuousMixin",
+ "FIFOScheduler",
+ "PrefillFirstScheduler",
+ "Scheduler",
]
_import_structure["utils"] = [
"GenerationMixin",
@@ -127,7 +131,13 @@
EarlyExitCandidateGenerator,
PromptLookupCandidateGenerator,
)
- from .continuous_batching import ContinuousMixin
+ from .continuous_batching import (
+ ContinuousBatchingManager,
+ ContinuousMixin,
+ FIFOScheduler,
+ PrefillFirstScheduler,
+ Scheduler,
+ )
from .logits_process import (
AlternatingCodebooksLogitsProcessor,
ClassifierFreeGuidanceLogitsProcessor,
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index d224eb96174a..483ff2bce452 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -105,8 +105,9 @@ class GenerationConfig(PushToHubMixin):
> Parameters that control the length of the output
max_length (`int`, *optional*, defaults to 20):
- The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
- `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set.
+            The maximum length the generated tokens can have, corresponding to the length of the input prompt +
+            `max_new_tokens`. Using `max_new_tokens` is recommended for controlling how many tokens the model
+            generates; `max_length` remains for backward compatibility and is overridden by `max_new_tokens` if both are set.
max_new_tokens (`int`, *optional*):
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
min_length (`int`, *optional*, defaults to 0):
diff --git a/src/transformers/generation/continuous_batching/__init__.py b/src/transformers/generation/continuous_batching/__init__.py
index 8d6800f7db35..75bf6178e978 100644
--- a/src/transformers/generation/continuous_batching/__init__.py
+++ b/src/transformers/generation/continuous_batching/__init__.py
@@ -15,12 +15,16 @@
from .cache import PagedAttentionCache
from .continuous_api import ContinuousBatchingManager, ContinuousMixin
from .requests import RequestState, RequestStatus
+from .scheduler import FIFOScheduler, PrefillFirstScheduler, Scheduler
__all__ = [
"ContinuousBatchingManager",
"ContinuousMixin",
+ "FIFOScheduler",
"PagedAttentionCache",
+ "PrefillFirstScheduler",
"RequestState",
"RequestStatus",
+ "Scheduler",
]
diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py
index a30497ed5483..5c0951ef8164 100644
--- a/src/transformers/generation/continuous_batching/continuous_api.py
+++ b/src/transformers/generation/continuous_batching/continuous_api.py
@@ -763,15 +763,9 @@ def __init__(
num_kv_padding_intervals: (optional) Number of intervals used to pad the keys/values dimension
allow_prefix_sharing: (optional) Whether to allow prefix sharing if the model has only full attention layers
"""
+        # Reload the paged version if necessary
if "paged|" not in model.config._attn_implementation:
- attn_implementation = f"paged|{model.config._attn_implementation}"
- model.config._attn_implementation = attn_implementation
-
- # lazy loading flash attention including kernel variations
- if "flash" in attn_implementation:
- from ...modeling_flash_attention_utils import lazy_import_paged_flash_attention
-
- lazy_import_paged_flash_attention(attn_implementation)
+ model.set_attn_implementation(f"paged|{model.config._attn_implementation}")
self.model = model.eval()
generation_config = model.generation_config if generation_config is None else generation_config
diff --git a/src/transformers/image_processing_base.py b/src/transformers/image_processing_base.py
index 2a75c96e72b4..cd59b56e2296 100644
--- a/src/transformers/image_processing_base.py
+++ b/src/transformers/image_processing_base.py
@@ -18,7 +18,7 @@
from typing import Any, Optional, TypeVar, Union
import numpy as np
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, is_offline_mode
from .dynamic_module_utils import custom_object_save
from .feature_extraction_utils import BatchFeature as BaseBatchFeature
@@ -28,7 +28,6 @@
PROCESSOR_NAME,
PushToHubMixin,
copy_func,
- is_offline_mode,
logging,
safe_load_json_file,
)
diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py
index 70142f2bf296..efa2b9e3150e 100644
--- a/src/transformers/integrations/accelerate.py
+++ b/src/transformers/integrations/accelerate.py
@@ -392,6 +392,15 @@ def _get_device_map(
)
else:
inferred_max_memory = get_max_memory(max_memory)
+
+            # If the user does not provide `max_memory`, accelerate sets the WHOLE available cpu memory as the limit.
+            # This is unwanted: when we are memory-constrained, we don't want to push cpu usage right up to that bound,
+            # especially if the model uses WeightConverter (the conversions cause some uncontrollable cpu memory spikes
+            # before we resave the weights). In those in-between cases, it's better to offload a bit more to disk,
+            # as otherwise we blow up cpu memory
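+            # (users can avoid this heuristic by passing an explicit max_memory, e.g. {0: "16GiB", "cpu": "24GiB"})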
+ if max_memory is None:
+ inferred_max_memory["cpu"] *= 0.90
+
if hf_quantizer is not None:
inferred_max_memory = hf_quantizer.adjust_max_memory(inferred_max_memory)
@@ -466,10 +475,10 @@ def expand_device_map(device_map, param_names):
def accelerate_disk_offload(
+ model: "PreTrainedModel",
disk_offload_folder: str | None,
checkpoint_files: list[str] | None,
device_map: dict,
- expected_keys: list[str],
sharded_metadata: dict | None,
dtype: torch.dtype | None,
weight_mapping=None,
@@ -493,7 +502,8 @@ def accelerate_disk_offload(
# In this case, the offload index is simply the existing safetensors (except if using custom weight loading
# Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time)
if is_offloaded_safetensors:
- param_device_map = expand_device_map(device_map, expected_keys)
+ meta_state_dict = model.state_dict()
+ param_device_map = expand_device_map(device_map, meta_state_dict.keys())
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
if sharded_metadata is None:
weight_map = dict.fromkeys(safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0])
@@ -502,7 +512,9 @@ def accelerate_disk_offload(
weight_map = {k: os.path.join(folder, v) for k, v in sharded_metadata["weight_map"].items()}
# Update the weight names according to the `weight_mapping`
- weight_renaming_map = {rename_source_key(k, renamings, [])[0]: k for k in weight_map}
+ weight_renaming_map = {
+ rename_source_key(k, renamings, [], model.base_model_prefix, meta_state_dict)[0]: k for k in weight_map
+ }
# Prepare the index using existing safetensors files
disk_offload_index = {
diff --git a/src/transformers/integrations/bitsandbytes.py b/src/transformers/integrations/bitsandbytes.py
index 9df71fd70d3b..5a04a906b199 100644
--- a/src/transformers/integrations/bitsandbytes.py
+++ b/src/transformers/integrations/bitsandbytes.py
@@ -44,7 +44,7 @@ def convert(
we need to store some parameters to create the quantized weight. For example, bnb requires 6 values that are stored in the checkpoint to recover the quantized weight. So we store them in a dict that it stored in hf_quantizer for now as we can't save it in the op since we create an op per tensor.
"""
value = list(input_dict.values())[0]
- value = value[0] if isinstance(value, list) else value
+ value = value[0]
# update param name to get the weights instead of the quantized stats
module, _ = get_module_from_name(model, full_layer_name)
@@ -223,7 +223,9 @@ def _replace_with_bnb_linear(
if pre_quantized:
# this is kind of an edge case when supporting both loading and quantization ...
# we need to set the right dtype as we cast the checkpoint with the dtype of the meta model
- new_module.weight.data = new_module.weight.data.to(dtype=torch.uint8)
+ new_module.weight.data = new_module.weight.data.to(
+ dtype=quantization_config.bnb_4bit_quant_storage
+ )
model._modules[name] = new_module
has_been_replaced = True
# Store the module class in case we need to transpose the weight later
diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py
index 35f725f9b696..796e7aa4bba2 100644
--- a/src/transformers/integrations/finegrained_fp8.py
+++ b/src/transformers/integrations/finegrained_fp8.py
@@ -606,7 +606,7 @@ def replace_with_fp8_linear(
module_kwargs = {} if pre_quantized else {"dtype": None}
new_module = None
with init_empty_weights():
- if "gate_up_proj" in module_name or "down_proj" in module_name and "experts" in module_name:
+ if module_name.endswith(".experts"):
new_module = FP8Expert(
config=model.config, block_size=quantization_config.weight_block_size, **module_kwargs
)
diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py
index d5600050188f..ec08846fa11d 100644
--- a/src/transformers/integrations/ggml.py
+++ b/src/transformers/integrations/ggml.py
@@ -76,7 +76,7 @@
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
- "qwen2moe": {
+ "qwen2_moe": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
"feed_forward_length": "intermediate_size",
diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py
index 8d3ae310687e..bf2c5b8d168e 100644
--- a/src/transformers/integrations/hub_kernels.py
+++ b/src/transformers/integrations/hub_kernels.py
@@ -78,7 +78,7 @@ def use_kernel_func_from_hub(func_name: str):
)
return lambda func: func
- _KERNEL_MAPPING: dict[str, dict[Device | str, LayerRepository]] = {
+ _KERNEL_MAPPING: dict[str, dict[Device | str, LayerRepository | dict[Mode, LayerRepository]]] = {
"MultiScaleDeformableAttention": {
"cuda": LayerRepository(
repo_id="kernels-community/deformable-detr",
@@ -328,7 +328,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
if kernel_name in mapping and isinstance(mapping[kernel_name], ModuleType):
return mapping[kernel_name]
if kernel_name not in _HUB_KERNEL_MAPPING:
- logger.warning(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
+ logger.warning_once(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
mapping[kernel_name] = None
return None
if _kernels_available:
diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py
index 29de72c05415..b89e6062ba7b 100755
--- a/src/transformers/integrations/integration_utils.py
+++ b/src/transformers/integrations/integration_utils.py
@@ -26,6 +26,7 @@
import shutil
import sys
import tempfile
+import warnings
from dataclasses import fields
from enum import Enum
from pathlib import Path
@@ -1455,6 +1456,10 @@ def __init__(self):
class NeptuneCallback(TrainerCallback):
"""TrainerCallback that sends the logs to [Neptune](https://app.neptune.ai).
+ > [!WARNING]
+ > Neptune integration is deprecated and will be removed in a future version of Transformers. We recommend using
+ > other supported experiment tracking integrations.
+
Args:
api_token (`str`, *optional*): Neptune API token obtained upon registration.
You can leave this argument out if you have saved your token to the `NEPTUNE_API_TOKEN` environment
@@ -1500,6 +1505,11 @@ def __init__(
log_checkpoints: str | None = None,
**neptune_run_kwargs,
):
+ warnings.warn(
+ "The NeptuneCallback is deprecated and will be removed in a future version of Transformers. We recommend "
+ "using other supported experiment tracking integrations.",
+ FutureWarning,
+ )
if not is_neptune_available():
raise ValueError(
"NeptuneCallback requires the Neptune client library to be installed. "
diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py
index 4c1c12a4058e..e8519d608c69 100644
--- a/src/transformers/integrations/peft.py
+++ b/src/transformers/integrations/peft.py
@@ -17,6 +17,7 @@
import os
from typing import Any, Literal
+from ..conversion_mapping import get_model_conversion_mapping
from ..core_model_loading import WeightRenaming, rename_source_key
from ..utils import (
CONFIG_NAME,
@@ -46,26 +47,6 @@
logger = logging.get_logger(__name__)
-# DO NOT MODIFY, KEPT FOR BC ONLY
-VLMS = [
- "aria",
- "ayavision",
- "emu3",
- "fuyu",
- "gotocr2",
- "gemma3",
- "internvl",
- "llava", # all llava prefixed models fall under this check
- "mistral3",
- "mllama",
- "paligemma",
- "qwen2vl",
- "qwen2_5_vl",
- "videollava",
- "vipllava",
-]
-
-
class PeftAdapterMixin:
"""
A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
@@ -211,11 +192,10 @@ def load_adapter(
if any(conf.peft_type != PeftType.LORA for conf in self.peft_config.values()):
raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.")
+ key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None
+ weight_conversions = get_model_conversion_mapping(self, key_mapping=key_mapping)
# peft only supports low_cpu_mem_usage starting from v0.13.0
peft_load_kwargs = {}
- key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None
- if key_mapping is None and any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS):
- key_mapping = self._checkpoint_conversion_mapping
peft_load_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
adapter_name = adapter_name if adapter_name is not None else "default"
@@ -279,9 +259,6 @@ def load_adapter(
)
peft_config.inference_mode = not is_trainable
- if peft_config.peft_type != PeftType.LORA:
- raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.")
-
if not hotswap:
# TODO: WE NEED TOO APPLY OUR DYNAMIC WEIGHT CONVERSION AT SOME POINT HERE!
# Create and add fresh new adapters into the model, unless the weights are hotswapped
@@ -295,17 +272,18 @@ def load_adapter(
# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
renamings = []
- if key_mapping:
- renamings = [entry for entry in key_mapping if isinstance(entry, WeightRenaming)]
+ if weight_conversions:
+ renamings = [entry for entry in weight_conversions if isinstance(entry, WeightRenaming)]
processed_adapter_state_dict = {}
prefix = "base_model.model."
+ state_dict = self.state_dict()
for key, value in adapter_state_dict.items():
if key.startswith(prefix):
new_key = key[len(prefix) :]
else:
new_key = key
- new_key = rename_source_key(new_key, renamings, [])[0]
+ new_key = rename_source_key(new_key, renamings, [], self.base_model_prefix, state_dict)[0]
# For hotswapping, we need the adapter name to be present in the state dict keys
if hotswap:
diff --git a/src/transformers/integrations/torchao.py b/src/transformers/integrations/torchao.py
index 22a776a7ec74..c87c35724b61 100644
--- a/src/transformers/integrations/torchao.py
+++ b/src/transformers/integrations/torchao.py
@@ -32,7 +32,7 @@
if is_torchao_available():
TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
- if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
+ if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.15.0"):
from torchao.prototype.safetensors.safetensors_support import (
unflatten_tensor_state_dict,
)
@@ -210,61 +210,55 @@ def __init__(self, hf_quantizer):
def convert(
self,
input_dict: dict[str, torch.Tensor],
+ source_patterns: list[str] | None = None,
model: Optional[torch.nn.Module] = None,
full_layer_name: str | None = None,
missing_keys=None,
**kwargs,
) -> dict[str, torch.Tensor]:
- if isinstance(self.hf_quantizer.quantization_config.quant_type, str):
- is_int_4 = "int4" in self.hf_quantizer.quantization_config.quant_type
- else:
- config_name = self.hf_quantizer.quantization_config.quant_type.__class__.__name__
- is_int_4 = fuzzy_match_size(config_name) == "4"
-
- # Simple case if we gather layermsnorm weights, we can just return the value since they are not quantized
- if "weight:_data" in input_dict.keys():
- value = (
- input_dict["weight:_data"][0]
- if isinstance(input_dict["weight:_data"], list)
- else input_dict["weight:_data"]
- )
- return {full_layer_name: value}
-
- is_unsafe_serialization = ":" not in list(input_dict.keys())[0]
+ """
+ Consolidates tensor subclass components before reconstructing the object
+
+ For example:
+ input_dict: {
+ "_weight_qdata": torch.Tensor,
+ "_weight_scale": torch.Tensor,
+ }
+ full_layer_name: "model.layers.0.self_attn.k_proj.weight"
+
+ Given this, we reconstruct a Float8Tensor instance using the qdata and scale
+ and return it as a dictionary with the full_layer_name as the key and the recovered
+ Float8Tensor instance as the value.
+ """
+ is_unsafe_serialization = list(input_dict.keys())[0] not in source_patterns
param_data = {}
+ layer_name = ".".join(full_layer_name.split(".")[:-1])
if is_unsafe_serialization:
if isinstance(input_dict["weight"], list):
weight = input_dict["weight"][0]
else:
weight = input_dict["weight"]
else:
- if isinstance(input_dict["weight:qdata"], list):
- param_data[f"{full_layer_name}:qdata"] = input_dict["weight:qdata"][0]
- else:
- param_data[f"{full_layer_name}:qdata"] = input_dict["weight:qdata"]
-
- if isinstance(input_dict["weight:scale"], list):
- param_data[f"{full_layer_name}:scale"] = input_dict["weight:scale"][0]
- else:
- param_data[f"{full_layer_name}:scale"] = input_dict["weight:scale"]
-
- if is_int_4:
- if isinstance(input_dict["weight:zero_point"], list):
- param_data[f"{full_layer_name}:zero_point"] = input_dict["weight:zero_point"][0]
- else:
- param_data[f"{full_layer_name}:zero_point"] = input_dict["weight:zero_point"]
+ for suffix in input_dict.keys():
+ if len(input_dict[suffix]) != 1:
+ raise ValueError(
+ f"Expected a single tensor for {suffix} but got {len(input_dict[suffix])} tensors instead"
+ )
+ param_data[f"{layer_name}.{suffix}"] = input_dict[suffix][0]
- # If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but was
- # already done) - if it's unsafe-serialized (i.e. not safetensors), not need for anything either
+ # If it's unsafe-serialized (i.e. not safetensors), no need for anything
if is_unsafe_serialization:
return {full_layer_name: weight}
# Sanity check for the new serialization format
- elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.hf_quantizer.metadata)):
- # print("metadata", self.hf_quantizer.metadata)
- raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
+ elif not (TORCHAO_VERSION >= version.parse("0.15.0") and is_metadata_torchao(self.hf_quantizer.metadata)):
+ raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.15.0` installed")
- new_param = unflatten_tensor_state_dict(param_data, self.hf_quantizer.metadata)[full_layer_name]
+ unflattened_state_dict, leftover_state_dict = unflatten_tensor_state_dict(
+ param_data, self.hf_quantizer.metadata
+ )
+ assert not leftover_state_dict # there should be no unprocessed tensors
+ new_param = unflattened_state_dict[full_layer_name]
module, _ = get_module_from_name(model, full_layer_name)
# Add repr to the module
diff --git a/src/transformers/modelcard.py b/src/transformers/modelcard.py
index 2a53bb9ba4ff..4cdafea12154 100644
--- a/src/transformers/modelcard.py
+++ b/src/transformers/modelcard.py
@@ -23,7 +23,7 @@
import httpx
import yaml
-from huggingface_hub import model_info
+from huggingface_hub import is_offline_mode, model_info
from huggingface_hub.errors import OfflineModeIsEnabled
from huggingface_hub.utils import HFValidationError
@@ -50,7 +50,6 @@
MODEL_CARD_NAME,
cached_file,
is_datasets_available,
- is_offline_mode,
is_tokenizers_available,
is_torch_available,
logging,
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 7a4db11bb86c..630bae4b6099 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -36,7 +36,7 @@
from zipfile import is_zipfile
import torch
-from huggingface_hub import create_repo, split_torch_state_dict_into_shards
+from huggingface_hub import create_repo, is_offline_mode, split_torch_state_dict_into_shards
from packaging import version
from safetensors import safe_open
from safetensors.torch import save_file as safe_save_file
@@ -85,7 +85,7 @@
verify_tp_plan,
)
from .loss.loss_utils import LOSS_MAPPING
-from .modeling_flash_attention_utils import lazy_import_flash_attention
+from .modeling_flash_attention_utils import lazy_import_flash_attention, lazy_import_paged_flash_attention
from .pytorch_utils import id_tensor_storage
from .quantizers import HfQuantizer
from .quantizers.auto import get_hf_quantizer
@@ -110,7 +110,6 @@
is_flash_attn_2_available,
is_flash_attn_3_available,
is_kernels_available,
- is_offline_mode,
is_torch_flex_attn_available,
is_torch_greater_or_equal,
is_torch_mlu_available,
@@ -1764,9 +1763,12 @@ def _check_and_adjust_attn_implementation(
"""
applicable_attn_implementation = attn_implementation
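+        # Implementations prefixed with "paged|" (e.g. "paged|flash_attention_2") share the flash-attention fallback
+        # logic below: the prefix is stripped for the checks and re-added once a kernel has been selected.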
+ is_paged = attn_implementation is not None and attn_implementation.startswith("paged|")
+
# If FA not installed, do not fail but use kernels instead
requested_original_flash_attn = attn_implementation is not None and (
- attn_implementation == "flash_attention_2" or attn_implementation == "flash_attention_3"
+ attn_implementation.removeprefix("paged|") == "flash_attention_2"
+ or attn_implementation.removeprefix("paged|") == "flash_attention_3"
)
if (
requested_original_flash_attn
@@ -1784,10 +1786,16 @@ def _check_and_adjust_attn_implementation(
else:
applicable_attn_implementation = "kernels-community/vllm-flash-attn3"
+ if is_paged:
+ applicable_attn_implementation = f"paged|{applicable_attn_implementation}"
+
if is_kernel(applicable_attn_implementation):
try:
# preload flash attention here to allow compile with fullgraph
- lazy_import_flash_attention(applicable_attn_implementation)
+ if is_paged:
+ lazy_import_paged_flash_attention(applicable_attn_implementation)
+ else:
+ lazy_import_flash_attention(applicable_attn_implementation)
# log that we used kernel fallback if successful
if requested_original_flash_attn:
@@ -2104,7 +2112,6 @@ def set_decoder(self, decoder):
possible_module_names = ["language_model", "text_model", "decoder"]
for name in possible_module_names:
if hasattr(self, name):
- print(name)
setattr(self, name, decoder)
return
@@ -4039,7 +4046,7 @@ def from_pretrained(
hf_quantizer.postprocess_model(model, config=config) # usually a no-op but sometimes needed
if _adapter_model_path is not None:
- adapter_kwargs["key_mapping"] = weight_conversions # TODO: Dynamic weight loader for adapters
+ adapter_kwargs["key_mapping"] = key_mapping
model.load_adapter(
_adapter_model_path,
adapter_name=adapter_name,
@@ -4090,10 +4097,10 @@ def _load_pretrained_model(
# Prepare parameters offloading if needed
if device_map is not None and "disk" in device_map.values():
disk_offload_index = accelerate_disk_offload(
+ model,
disk_offload_folder,
checkpoint_files,
device_map,
- expected_keys,
sharded_metadata,
dtype,
weight_mapping,
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index fdcd98a85385..71b2155e9bc5 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -186,6 +186,7 @@
from .jetmoe import *
from .kosmos2 import *
from .kyutai_speech_to_text import *
+ from .lasr import *
from .layoutlm import *
from .layoutlmv2 import *
from .layoutlmv3 import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 3df37eeb3468..38a0abb9e2d7 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -221,6 +221,8 @@
("kosmos-2", "Kosmos2Config"),
("kosmos-2.5", "Kosmos2_5Config"),
("kyutai_speech_to_text", "KyutaiSpeechToTextConfig"),
+ ("lasr_ctc", "LasrCTCConfig"),
+ ("lasr_encoder", "LasrEncoderConfig"),
("layoutlm", "LayoutLMConfig"),
("layoutlmv2", "LayoutLMv2Config"),
("layoutlmv3", "LayoutLMv3Config"),
@@ -662,6 +664,9 @@
("kosmos-2", "KOSMOS-2"),
("kosmos-2.5", "KOSMOS-2.5"),
("kyutai_speech_to_text", "KyutaiSpeechToText"),
+ ("lasr", "Lasr"),
+ ("lasr_ctc", "Lasr"),
+ ("lasr_encoder", "LasrEncoder"),
("layoutlm", "LayoutLM"),
("layoutlmv2", "LayoutLMv2"),
("layoutlmv3", "LayoutLMv3"),
@@ -977,6 +982,8 @@
("video_llama_3_vision", "video_llama_3"),
("parakeet_encoder", "parakeet"),
("parakeet_ctc", "parakeet"),
+ ("lasr_encoder", "lasr"),
+ ("lasr_ctc", "lasr"),
("wav2vec2-bert", "wav2vec2_bert"),
]
)
diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index 4a7ba3272238..a9008af06ab6 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -49,6 +49,8 @@
("granite_speech", "GraniteSpeechFeatureExtractor"),
("hubert", "Wav2Vec2FeatureExtractor"),
("kyutai_speech_to_text", "KyutaiSpeechToTextFeatureExtractor"),
+ ("lasr_ctc", "LasrFeatureExtractor"),
+ ("lasr_encoder", "LasrFeatureExtractor"),
("markuplm", "MarkupLMFeatureExtractor"),
("mimi", "EncodecFeatureExtractor"),
("moonshine", "Wav2Vec2FeatureExtractor"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index dd4997392617..ddd29ad96d5b 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -222,6 +222,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("kosmos-2", "Kosmos2Model"),
("kosmos-2.5", "Kosmos2_5Model"),
("kyutai_speech_to_text", "KyutaiSpeechToTextModel"),
+ ("lasr_ctc", "LasrForCTC"),
+ ("lasr_encoder", "LasrEncoder"),
("layoutlm", "LayoutLMModel"),
("layoutlmv2", "LayoutLMv2Model"),
("layoutlmv3", "LayoutLMv3Model"),
@@ -1583,6 +1585,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
# Model for Connectionist temporal classification (CTC) mapping
("data2vec-audio", "Data2VecAudioForCTC"),
("hubert", "HubertForCTC"),
+ ("lasr_ctc", "LasrForCTC"),
("parakeet_ctc", "ParakeetForCTC"),
("sew", "SEWForCTC"),
("sew-d", "SEWDForCTC"),
diff --git a/src/transformers/models/dac/modeling_dac.py b/src/transformers/models/dac/modeling_dac.py
index d74a369d3d33..c52630c13a18 100644
--- a/src/transformers/models/dac/modeling_dac.py
+++ b/src/transformers/models/dac/modeling_dac.py
@@ -264,7 +264,7 @@ def forward(self, hidden_state):
return hidden_state
-class DacResidualVectorQuantize(nn.Module):
+class DacResidualVectorQuantizer(nn.Module):
"""
ResidualVectorQuantize block - Introduced in SoundStream: An end2end neural audio codec (https://huggingface.co/papers/2107.03312)
"""
@@ -568,7 +568,7 @@ def __init__(self, config: DacConfig):
self.encoder = DacEncoder(config)
self.decoder = DacDecoder(config)
- self.quantizer = DacResidualVectorQuantize(config)
+ self.quantizer = DacResidualVectorQuantizer(config)
self.bits_per_codebook = int(math.log2(self.config.codebook_size))
if 2**self.bits_per_codebook != self.config.codebook_size:
diff --git a/src/transformers/models/eomt/image_processing_eomt_fast.py b/src/transformers/models/eomt/image_processing_eomt_fast.py
index 68fd7bb00744..c0e35c0c12f6 100644
--- a/src/transformers/models/eomt/image_processing_eomt_fast.py
+++ b/src/transformers/models/eomt/image_processing_eomt_fast.py
@@ -239,7 +239,7 @@ def _preprocess(
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
- resized_images_grouped[shape] = stacked_images
+ resized_images_grouped[shape] = stacked_images
images = reorder_images(resized_images_grouped, grouped_images_index)
# Group images by size for batched resizing, Needed in case do_resize is False.
diff --git a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py
index 08b2f0265e79..643eb76ff455 100644
--- a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py
+++ b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py
@@ -514,7 +514,7 @@ def forward(self, hidden_states, attention_mask=None):
Args:
hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
- attention_mask (`torch.Tensor` of shape `(batch, 1, time)`): Attention mask.
+ attention_mask (`torch.Tensor` of shape `(batch, 1, time, time)`): Attention mask.
Returns:
`torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
@@ -530,7 +530,10 @@ def forward(self, hidden_states, attention_mask=None):
# Apply padding mask before convolution
if attention_mask is not None:
- all_masked_rows = torch.all(~attention_mask, dim=-1)
+ if attention_mask.dtype == torch.bool:
+ all_masked_rows = torch.all(~attention_mask, dim=2)
+ else:
+ all_masked_rows = torch.all(~(attention_mask == 0.0), dim=2)
hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)
# 1D Depthwise Conv
diff --git a/src/transformers/models/glm4v/configuration_glm4v.py b/src/transformers/models/glm4v/configuration_glm4v.py
index 35c29f07246d..f707ab291a8c 100644
--- a/src/transformers/models/glm4v/configuration_glm4v.py
+++ b/src/transformers/models/glm4v/configuration_glm4v.py
@@ -234,7 +234,9 @@ def __init__(
self.attention_dropout = attention_dropout
self.rope_parameters = rope_parameters
- super().__init__(tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope"}, **kwargs)
+ super().__init__(
+ tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs
+ )
class Glm4vConfig(PreTrainedConfig):
diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py
index 2cd6c5d0fd06..7f81d03f8ac9 100644
--- a/src/transformers/models/glm4v/modular_glm4v.py
+++ b/src/transformers/models/glm4v/modular_glm4v.py
@@ -271,7 +271,9 @@ def __init__(
self.attention_dropout = attention_dropout
self.rope_parameters = rope_parameters
- super().__init__(tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope"}, **kwargs)
+ super().__init__(
+ tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs
+ )
class Glm4vConfig(PreTrainedConfig):
diff --git a/src/transformers/models/glm4v_moe/configuration_glm4v_moe.py b/src/transformers/models/glm4v_moe/configuration_glm4v_moe.py
index 20e4f3ad492c..fdfb96f75294 100644
--- a/src/transformers/models/glm4v_moe/configuration_glm4v_moe.py
+++ b/src/transformers/models/glm4v_moe/configuration_glm4v_moe.py
@@ -280,7 +280,9 @@ def __init__(
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
self.router_aux_loss_coef = router_aux_loss_coef
- super().__init__(tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope"}, **kwargs)
+ super().__init__(
+ tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs
+ )
class Glm4vMoeConfig(PreTrainedConfig):
diff --git a/src/transformers/models/glm4v_moe/modular_glm4v_moe.py b/src/transformers/models/glm4v_moe/modular_glm4v_moe.py
index 06967fb07642..71c213f940d1 100644
--- a/src/transformers/models/glm4v_moe/modular_glm4v_moe.py
+++ b/src/transformers/models/glm4v_moe/modular_glm4v_moe.py
@@ -227,7 +227,7 @@ def __init__(
self.norm_topk_prob = norm_topk_prob
self.router_aux_loss_coef = router_aux_loss_coef
PreTrainedConfig.__init__(
- self, tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope"}, **kwargs
+ self, tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs
)
diff --git a/src/transformers/models/lasr/__init__.py b/src/transformers/models/lasr/__init__.py
new file mode 100644
index 000000000000..f4c7c98261a1
--- /dev/null
+++ b/src/transformers/models/lasr/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_lasr import *
+ from .feature_extraction_lasr import *
+ from .modeling_lasr import *
+ from .tokenization_lasr import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/lasr/configuration_lasr.py b/src/transformers/models/lasr/configuration_lasr.py
new file mode 100644
index 000000000000..28051469be58
--- /dev/null
+++ b/src/transformers/models/lasr/configuration_lasr.py
@@ -0,0 +1,244 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/lasr/modular_lasr.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_lasr.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+
+from ...configuration_utils import PreTrainedConfig
+
+
+class LasrEncoderConfig(PreTrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LasrEncoder`]. It is used to instantiate a
+ `LasrEncoder` model according to the specified arguments, defining the model architecture.
+
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PreTrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 512):
+ Dimension of the layers and the hidden states.
+ num_hidden_layers (`int`, *optional*, defaults to 17):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 2048):
+ Dimension of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the encoder and pooler.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in the attention layers.
+ convolution_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in convolutions of the conformer's convolution module.
+ conv_kernel_size (`int`, *optional*, defaults to 32):
+ The kernel size of the convolution layers in the Conformer block.
+ subsampling_conv_channels (`int`, *optional*, defaults to 256):
+ The number of channels in the subsampling convolution layers.
+ subsampling_conv_kernel_size (`int`, *optional*, defaults to 5):
+ The kernel size of the subsampling convolution layers.
+ subsampling_conv_stride (`int`, *optional*, defaults to 2):
+ The stride of the subsampling convolution layers.
+ num_mel_bins (`int`, *optional*, defaults to 128):
+ Number of mel features.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for all fully connected layers in the embeddings, encoder, and pooler.
+        dropout_positions (`float`, *optional*, defaults to 0.0):
+            The dropout ratio applied to the rotary position embeddings.
+        layerdrop (`float`, *optional*, defaults to 0.1):
+            The LayerDrop probability for the encoder layers. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+ activation_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for activations inside the fully connected layer.
+ attention_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention layers.
+ max_position_embeddings (`int`, *optional*, defaults to 10000):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+        feed_forward_residual_weights (`tuple[float, float]`, *optional*, defaults to `[1.5, 0.5]`):
+            The residual weights for the feed forward layers: the first element scales the residual branch and the
+            second one scales the feed-forward output.
+        conv_residual_weights (`tuple[float, float]`, *optional*, defaults to `[2.0, 1.0]`):
+            The residual weights for the convolution module: the first element scales the residual branch and the
+            second one scales the convolution output.
+ batch_norm_momentum (`float`, *optional*, defaults to 0.01):
+ The momentum for the batch normalization layers.
+ rope_parameters (`RopeParameters`, *optional*):
+ Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
+ a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
+ with longer `max_position_embeddings`.
+
+ Example:
+ ```python
+ >>> from transformers import LasrEncoderModel, LasrEncoderConfig
+
+ >>> # Initializing a `LasrEncoder` configuration
+ >>> configuration = LasrEncoderConfig()
+
+ >>> # Initializing a model from the configuration
+ >>> model = LasrEncoderModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+
+ This configuration class is based on the LasrEncoder architecture from Google Health AI. You can find more details
+ and pre-trained models at [TODO/TODO](https://huggingface.co/TODO/TODO).
+ """
+
+ model_type = "lasr_encoder"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ hidden_size=512,
+ num_hidden_layers=17,
+ num_attention_heads=8,
+ intermediate_size=2048,
+ hidden_act="silu",
+ attention_bias=False,
+ convolution_bias=False,
+ conv_kernel_size=32,
+ subsampling_conv_channels=256,
+ subsampling_conv_kernel_size=5,
+ subsampling_conv_stride=2,
+ num_mel_bins=128,
+ dropout=0.1,
+ dropout_positions=0.0,
+ layerdrop=0.1,
+ activation_dropout=0.1,
+ attention_dropout=0.1,
+ max_position_embeddings=10000,
+ initializer_range=0.02,
+ layer_norm_eps=1e-6,
+ feed_forward_residual_weights=[1.5, 0.5],
+ conv_residual_weights=[2.0, 1.0],
+ batch_norm_momentum=0.01,
+ rope_parameters=None,
+ **kwargs,
+ ):
+ self.rope_parameters = rope_parameters
+ self.layer_norm_eps = layer_norm_eps
+ self.feed_forward_residual_weights = feed_forward_residual_weights
+ self.conv_residual_weights = conv_residual_weights
+ self.batch_norm_momentum = batch_norm_momentum
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_attention_heads # LlamaAttention compatibility
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.attention_bias = attention_bias
+ self.convolution_bias = convolution_bias
+
+ self.conv_kernel_size = conv_kernel_size
+ self.subsampling_conv_kernel_size = subsampling_conv_kernel_size
+ self.subsampling_conv_stride = subsampling_conv_stride
+ self.subsampling_conv_channels = subsampling_conv_channels
+ self.num_mel_bins = num_mel_bins
+
+ self.dropout = dropout
+ self.dropout_positions = dropout_positions
+ self.layerdrop = layerdrop
+ self.activation_dropout = activation_dropout
+ self.attention_dropout = attention_dropout
+ self.max_position_embeddings = max_position_embeddings
+ self.initializer_range = initializer_range
+
+ super().__init__(
+ **kwargs,
+ )
+
+
+class LasrCTCConfig(PreTrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LasrForCTC`]. It is used to instantiate a
+ Lasr CTC model according to the specified arguments, defining the model architecture.
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PreTrainedConfig`] for more information.
+ Args:
+ vocab_size (`int`, *optional*, defaults to 512):
+ Vocabulary size of the model.
+ ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`):
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+ instance of [`LasrForCTC`].
+ ctc_zero_infinity (`bool`, *optional*, defaults to `True`):
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
+ of [`LasrForCTC`].
+ encoder_config (`Union[dict, LasrEncoderConfig]`, *optional*):
+ The config object or dictionary of the encoder.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id. Also used as blank token id.
+ Example:
+ ```python
+ >>> from transformers import LasrForCTC, LasrCTCConfig
+ >>> # Initializing a Lasr configuration
+ >>> configuration = LasrCTCConfig()
+ >>> # Initializing a model from the configuration
+ >>> model = LasrForCTC(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ This configuration class is based on the Lasr CTC architecture from Google Health AI. You can find more details
+ and pre-trained models at [TODO/TODO](https://huggingface.co/TODO/TODO).
+ """
+
+ model_type = "lasr_ctc"
+ sub_configs = {"encoder_config": LasrEncoderConfig}
+
+ def __init__(
+ self,
+ vocab_size=512,
+ ctc_loss_reduction="mean",
+ ctc_zero_infinity=True,
+ encoder_config: Union[dict, LasrEncoderConfig] = None,
+ pad_token_id=0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.ctc_loss_reduction = ctc_loss_reduction
+ self.ctc_zero_infinity = ctc_zero_infinity
+
+ if isinstance(encoder_config, dict):
+ self.encoder_config = LasrEncoderConfig(**encoder_config)
+        elif encoder_config is None:
+            self.encoder_config = LasrEncoderConfig()
+        else:
+            self.encoder_config = encoder_config
+
+ self.initializer_range = self.encoder_config.initializer_range
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_encoder_config(cls, encoder_config: LasrEncoderConfig, **kwargs):
+ r"""
+ Instantiate a [`LasrCTCConfig`] (or a derived class) from lasr encoder model configuration.
+
+ Returns:
+ [`LasrCTCConfig`]: An instance of a configuration object
+ """
+
+ return cls(encoder_config=encoder_config.to_dict(), **kwargs)
+
+
+__all__ = ["LasrEncoderConfig", "LasrCTCConfig"]
diff --git a/src/transformers/models/lasr/feature_extraction_lasr.py b/src/transformers/models/lasr/feature_extraction_lasr.py
new file mode 100644
index 000000000000..50a0229838aa
--- /dev/null
+++ b/src/transformers/models/lasr/feature_extraction_lasr.py
@@ -0,0 +1,277 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Union
+
+import numpy as np
+import torch
+
+from ...audio_utils import hertz_to_mel
+from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ...feature_extraction_utils import BatchFeature
+from ...utils import TensorType, logging
+from ...utils.import_utils import requires
+
+
+logger = logging.get_logger(__name__)
+
+
+# TODO: @eustlb, we should be able to remove this and use mel_filter_bank from audio_utils
+def linear_to_mel_weight_matrix(
+ num_mel_bins: int,
+ num_spectrogram_bins: int,
+ sample_rate: float,
+ lower_edge_hertz: float,
+ upper_edge_hertz: float,
+ dtype,
+) -> np.ndarray:
+ """NumPy-port of the JAX mel weight matrix logic."""
+ # We use float64 for precision, matching the JAX implementation.
+ internal_dtype = np.float64
+
+ # HTK excludes the spectrogram DC bin.
+ bands_to_zero = 1
+ nyquist_hertz = sample_rate / 2.0
+ linear_frequencies = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins, dtype=internal_dtype)[bands_to_zero:]
+ spectrogram_bins_mel = hertz_to_mel(linear_frequencies, mel_scale="kaldi")[:, np.newaxis]
+
+ edges = np.linspace(
+ hertz_to_mel(lower_edge_hertz, mel_scale="kaldi"),
+ hertz_to_mel(upper_edge_hertz, mel_scale="kaldi"),
+ num_mel_bins + 2,
+ dtype=internal_dtype,
+ )
+
+ lower_edge_mel, center_mel, upper_edge_mel = (
+ edges[:-2][np.newaxis, :],
+ edges[1:-1][np.newaxis, :],
+ edges[2:][np.newaxis, :],
+ )
+
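+    # Triangular mel filters: a rising slope from the lower edge to the center and a falling slope from the center
+    # to the upper edge, clipped at zero outside each band.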
+ lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / (center_mel - lower_edge_mel)
+ upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / (upper_edge_mel - center_mel)
+ mel_weights_matrix = np.maximum(0.0, np.minimum(lower_slopes, upper_slopes))
+ return np.pad(mel_weights_matrix, [[bands_to_zero, 0], [0, 0]]).astype(dtype)
+
+
+@requires(backends=("torch",))
+class LasrFeatureExtractor(SequenceFeatureExtractor):
+ r"""
+ Constructs a LASR feature extractor.
+
+ This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
+ most of the main methods. Users should refer to this superclass for more information regarding those methods.
+
+    This class extracts mel-filter bank features from raw speech using a custom PyTorch implementation of the `Short Time
+    Fourier Transform` in which each frame of `win_length` samples is zero-padded to `n_fft` points (this differs slightly
+    from `torch.stft`, which unfolds frames of `n_fft` samples directly).
+
+ Args:
+ feature_size (`int`, *optional*, defaults to 128):
+ The feature dimension of the extracted features.
+ sampling_rate (`int`, *optional*, defaults to 16000):
+ The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
+ hop_length (`int`, *optional*, defaults to 160):
+ Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients.
+ n_fft (`int`, *optional*, defaults to 512):
+ Size of the Fourier transform.
+ win_length (`int`, *optional*, defaults to 400):
+ The window length for the STFT computation.
+ padding_value (`float`, *optional*, defaults to 0.0):
+ Padding value used to pad the audio. Should correspond to silences.
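+
+    Example (a minimal sketch using the default configuration; the audio below is synthetic and only illustrates
+    the expected input/output shapes):
+
+    ```python
+    >>> import numpy as np
+    >>> from transformers import LasrFeatureExtractor
+
+    >>> feature_extractor = LasrFeatureExtractor()
+    >>> audio = np.zeros(16000, dtype=np.float32)  # 1 second of silence sampled at 16 kHz
+    >>> features = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
+    >>> print(features["input_features"].shape)  # (batch, frames, feature_size)
+    ```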
+ """
+
+ model_input_names = ["input_features", "attention_mask"]
+
+ def __init__(
+ self,
+ feature_size=128,
+ sampling_rate=16000,
+ hop_length=160,
+ n_fft=512,
+ win_length=400,
+ padding_value=0.0,
+ **kwargs,
+ ):
+ super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
+
+ self.hop_length = hop_length
+ self.n_fft = n_fft
+ self.win_length = win_length
+ self.mel_filters = torch.from_numpy(
+ linear_to_mel_weight_matrix(
+ num_mel_bins=feature_size,
+ num_spectrogram_bins=n_fft // 2 + 1,
+ sample_rate=sampling_rate,
+ lower_edge_hertz=125.0,
+ upper_edge_hertz=7500.0,
+ dtype=np.float64,
+ )
+ )
+
+ def _torch_extract_fbank_features(self, waveform, device="cpu"):
+ # spectrogram
+ window = torch.hann_window(self.win_length, periodic=False, device=device, dtype=torch.float64)
+ waveform = waveform.to(torch.float64)
+
+ # TODO: @eustlb, to be standardized
+ # here we cannot use directly torch.stft because every fft frame is padded with zeros
+ # due to unfold then rfft, while torch.stft unfolds with the number of fft points
+ frames = waveform.unfold(-1, self.win_length, self.hop_length)
+ stft = torch.fft.rfft(window * frames, n=self.n_fft)
+ power_spec = torch.abs(stft) ** 2
+
+ # log mel spectrogram
+ mel_filters = self.mel_filters.to(device)
+ mel_spec = torch.clamp(power_spec @ mel_filters, min=1e-5)
+ mel_spec = torch.log(mel_spec)
+
+ return mel_spec
+
+ def __call__(
+ self,
+ raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
+ truncation: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_attention_mask: Optional[bool] = None,
+ padding: Optional[str] = "longest",
+ max_length: Optional[int] = None,
+ sampling_rate: Optional[int] = None,
+ do_normalize: Optional[bool] = None,
+ device: Optional[str] = "cpu",
+ return_token_timestamps: Optional[bool] = None,
+ **kwargs,
+ ) -> BatchFeature:
+ """
+        Main method to featurize and prepare for the model one or several sequence(s). The implementation uses PyTorch
+        for the STFT computation, so a working `torch` installation is required.
+
+ Args:
+ raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
+ The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
+ values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
+ stereo, i.e. single float per timestep.
+            truncation (`bool`, *optional*, defaults to `False`):
+ Activates truncation to cut input sequences longer than *max_length* to *max_length*.
+ pad_to_multiple_of (`int`, *optional*, defaults to None):
+ If set will pad the sequence to a multiple of the provided value.
+
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
+ return_attention_mask (`bool`, *optional*):
+ Whether to return the attention mask. If left to the default, will return the attention mask according
+ to the specific feature_extractor's default.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+                <Tip>
+
+                For Lasr models, `attention_mask` should always be passed for batched inference, to avoid subtle
+ bugs.
+
+                </Tip>
+
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ sampling_rate (`int`, *optional*):
+ The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
+ `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition
+ pipeline.
+ padding_value (`float`, *optional*, defaults to 0.0):
+ The value that is used to fill the padding values / vectors.
+ do_normalize (`bool`, *optional*, defaults to `False`):
+ Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
+ improve the performance of the model.
+ device (`str`, *optional*, defaults to `'cpu'`):
+ Specifies the device for computation of the log-mel spectrogram of audio signals in the
+ `_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
+ return_token_timestamps (`bool`, *optional*, defaults to `None`):
+ Deprecated. Use `return_attention_mask` instead from which the number of frames can be inferred.
+
+ Whether or not to return the number of frames of the input raw_speech.
+ These num_frames can be used by the model to compute word level timestamps.
+ """
+ if sampling_rate is not None:
+ if sampling_rate != self.sampling_rate:
+ raise ValueError(
+ f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
+ f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
+ f" was sampled with {self.sampling_rate} and not {sampling_rate}."
+ )
+ else:
+ logger.warning(
+ f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
+ "Failing to do so can result in silent errors that might be hard to debug."
+ )
+
+ # Convert to torch tensor
+ if isinstance(raw_speech, np.ndarray):
+ raw_speech = torch.tensor(raw_speech)
+ elif isinstance(raw_speech, (list, tuple)):
+ if isinstance(raw_speech[0], (list, np.ndarray)):
+ raw_speech = [torch.tensor(speech) for speech in raw_speech]
+ else: # list[float]
+ raw_speech = torch.tensor(raw_speech)
+
+ is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1
+ if is_batched_torch and len(raw_speech.shape) > 2:
+ logger.warning(
+ f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
+ "We will take the mean of the channels to convert to mono."
+ )
+ raw_speech = raw_speech.mean(-1)
+
+ is_batched_sequence = isinstance(raw_speech, (list, tuple))
+ if is_batched_sequence:
+            for idx, speech in enumerate(raw_speech):
+                if len(speech.shape) > 1:
+                    logger.warning(
+                        f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
+                        "We will take the mean of the channels to convert to mono."
+                    )
+                    # write the mono-converted audio back into the list so the conversion is not discarded
+                    raw_speech[idx] = speech.mean(-1)
+
+ if is_batched_torch or is_batched_sequence:
+ raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech]
+ else:
+ raw_speech = [raw_speech[:, None].to(torch.float32)]
+
+ batched_speech = BatchFeature({"input_features": raw_speech})
+ padded_inputs = self.pad(
+ batched_speech,
+ padding=padding,
+ max_length=max_length,
+ truncation=truncation,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ return_tensors="pt",
+ )
+ input_features = padded_inputs.input_features.squeeze(-1)
+ input_features = self._torch_extract_fbank_features(input_features, device)
+ data = {
+ "input_features": input_features.to(torch.float32),
+ }
+
+ if return_attention_mask:
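+            # downsample the per-sample padding mask to one value per STFT frame: frame i is valid when its last
+            # sample (at offset win_length - 1 + i * hop_length) lies inside the unpadded audio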
+ attention_mask = padded_inputs.attention_mask[:, self.win_length - 1 :: self.hop_length]
+ data["attention_mask"] = attention_mask.to(torch.bool)
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["LasrFeatureExtractor"]
diff --git a/src/transformers/models/lasr/modeling_lasr.py b/src/transformers/models/lasr/modeling_lasr.py
new file mode 100644
index 000000000000..802ff3ea9ad5
--- /dev/null
+++ b/src/transformers/models/lasr/modeling_lasr.py
@@ -0,0 +1,729 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/lasr/modular_lasr.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_lasr.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...integrations import use_kernel_func_from_hub
+from ...masking_utils import create_bidirectional_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import check_model_inputs
+from .configuration_lasr import LasrCTCConfig, LasrEncoderConfig
+
+
+class LasrEncoderSubsampling(nn.Module):
+ def __init__(self, config: LasrEncoderConfig):
+ super().__init__()
+ self.dense_0 = nn.Linear(config.num_mel_bins, config.hidden_size)
+ self.conv_0 = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=config.subsampling_conv_kernel_size,
+ stride=config.subsampling_conv_stride,
+ )
+ self.conv_1 = nn.Conv1d(
+ config.hidden_size,
+ config.subsampling_conv_channels,
+ kernel_size=config.subsampling_conv_kernel_size,
+ stride=config.subsampling_conv_stride,
+ )
+ self.dense_1 = nn.Linear(config.subsampling_conv_channels, config.hidden_size)
+ self.act_fn = nn.ReLU()
+
+ def forward(self, input_features: torch.Tensor) -> torch.Tensor:
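+        # (batch, frames, num_mel_bins) -> (batch, subsampled_frames, hidden_size)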
+ hidden_states = self.act_fn(self.dense_0(input_features))
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.act_fn(self.conv_0(hidden_states))
+ hidden_states = self.act_fn(self.conv_1(hidden_states))
+ hidden_states = hidden_states.transpose(1, 2)
+ return self.dense_1(hidden_states)
+
+
+class LasrEncoderRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: LasrEncoderConfig, device=None):
+ super().__init__()
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+
+ self.rope_type = self.config.rope_parameters["rope_type"]
+ rope_init_fn: Callable = self.compute_default_rope_parameters
+ if self.rope_type != "default":
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = inv_freq
+
+ @staticmethod
+ def compute_default_rope_parameters(
+ config: Optional[LasrEncoderConfig] = None,
+ device: Optional["torch.device"] = None,
+ seq_len: Optional[int] = None,
+ ) -> tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies according to the original RoPE implementation
+ Args:
+ config ([`~transformers.PreTrainedConfig`]):
+ The model configuration.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+ """
+ base = config.rope_parameters["rope_theta"]
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+
+ attention_factor = 1.0 # Unused in this type of RoPE
+
+ # Compute the inverse frequencies
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+ )
+ return inv_freq, attention_factor
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+@use_kernel_func_from_hub("rotary_pos_emb")
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class LasrEncoderAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: LasrEncoderConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = False
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+ self.rotary_fn = apply_rotary_pos_emb
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class LasrEncoderConvolutionModule(nn.Module):
+ def __init__(self, config: LasrEncoderConfig, module_config=None):
+ """
+ Args:
+ config (LasrEncoderConfig): Configuration for the model.
+ module_config (dict): Configuration for the module (e.g., encoder or decoder).
+ """
+ super().__init__()
+ channels = config.hidden_size
+ # kernel_size should be an odd number for 'SAME' padding
+ if module_config is None:
+            # e.g. using `LasrEncoderConfig` from src/transformers/models/lasr/configuration_lasr.py
+ kernel_size = config.conv_kernel_size
+ self.activation = ACT2FN[getattr(config, "hidden_act", "silu")]
+ else:
+ kernel_size = module_config["kernel_size"]
+ self.activation = ACT2FN[module_config.get("activation", "silu")]
+ self.padding = "same"
+ self.pointwise_conv1 = nn.Conv1d(
+ channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
+ )
+ self.depthwise_conv = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=self.padding,
+ groups=channels,
+ bias=config.convolution_bias,
+ )
+ self.norm = nn.BatchNorm1d(config.hidden_size, momentum=config.batch_norm_momentum)
+ self.pointwise_conv2 = nn.Conv1d(
+ channels, channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
+ )
+
+ def forward(self, hidden_states, attention_mask=None):
+ """
+ Compute convolution module.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
+ attention_mask (`torch.Tensor` of shape `(batch, 1, time, time)`): Attention mask.
+
+ Returns:
+ `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
+
+ """
+ # exchange the temporal dimension and the feature dimension
+ hidden_states = hidden_states.transpose(1, 2)
+
+ # GLU mechanism, (batch_size, 2*channel, dim)
+ hidden_states = self.pointwise_conv1(hidden_states)
+ # (batch_size, channel, dim)
+ hidden_states = nn.functional.glu(hidden_states, dim=1)
+
+ # Apply padding mask before convolution
+ if attention_mask is not None:
+ if attention_mask.dtype == torch.bool:
+ all_masked_rows = torch.all(~attention_mask, dim=2)
+ else:
+ all_masked_rows = torch.all(~(attention_mask == 0.0), dim=2)
+ hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)
+
+ # 1D Depthwise Conv
+ hidden_states = self.depthwise_conv(hidden_states)
+ hidden_states = self.norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.pointwise_conv2(hidden_states)
+
+ return hidden_states.transpose(1, 2)
+
+
+class LasrEncoderFeedForward(nn.Module):
+ def __init__(self, config: LasrEncoderConfig):
+ super().__init__()
+ self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.attention_bias)
+ self.activation = ACT2FN[config.hidden_act]
+ self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.attention_bias)
+ self.activation_dropout = config.activation_dropout
+
+ def forward(self, hidden_states):
+ hidden_states = self.activation(self.linear1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.linear2(hidden_states)
+ return hidden_states
+
+
+class LasrEncoderBlock(GradientCheckpointingLayer):
+ def __init__(self, config: LasrEncoderConfig, layer_idx: int):
+ super().__init__()
+ self.gradient_checkpointing = False
+
+ self.feed_forward1 = LasrEncoderFeedForward(config)
+ self.self_attn = LasrEncoderAttention(config, layer_idx)
+ self.conv = LasrEncoderConvolutionModule(config)
+ self.feed_forward2 = LasrEncoderFeedForward(config)
+
+ self.norm_feed_forward1 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
+ self.norm_self_att = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
+ self.norm_conv = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
+ self.norm_feed_forward2 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
+ self.norm_out = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
+
+ self.feed_forward_residual_weights = config.feed_forward_residual_weights
+ self.conv_residual_weights = config.conv_residual_weights
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
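+        # Conformer-style block: feed-forward -> self-attention -> convolution -> feed-forward, each sub-module
+        # pre-normalized and combined through (weighted) residual connections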
+ residual = hidden_states
+ hidden_states = self.feed_forward1(self.norm_feed_forward1(hidden_states))
+ hidden_states = (
+ self.feed_forward_residual_weights[0] * residual + self.feed_forward_residual_weights[1] * hidden_states
+ )
+
+ normalized_hidden_states = self.norm_self_att(hidden_states)
+ attn_output, _ = self.self_attn(
+ hidden_states=normalized_hidden_states,
+ attention_mask=attention_mask,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = hidden_states + attn_output
+
+ conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
+ hidden_states = self.conv_residual_weights[0] * hidden_states + self.conv_residual_weights[1] * conv_output
+
+ residual = hidden_states
+ hidden_states = self.feed_forward2(self.norm_feed_forward2(hidden_states))
+ hidden_states = (
+ self.feed_forward_residual_weights[0] * residual + self.feed_forward_residual_weights[1] * hidden_states
+ )
+
+ hidden_states = self.norm_out(hidden_states)
+
+ return hidden_states
+
+
+@auto_docstring
+class LasrPreTrainedModel(PreTrainedModel):
+ config: LasrCTCConfig
+ base_model_prefix = "model"
+ main_input_name = "input_features"
+ input_modalities = "audio"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LasrEncoderBlock"]
+ _supports_flat_attention_mask = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ # TODO: @eustlb, add support when flash attention supports custom attention bias
+ _supports_flash_attn = False
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": LasrEncoderBlock,
+ "attentions": LasrEncoderAttention,
+ }
+
+ @torch.no_grad()
+ def _init_weights(self, module):
+ super()._init_weights(module)
+
+ def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
+ encoder_config = self.config.encoder_config if isinstance(self.config, LasrCTCConfig) else self.config
+ kernel_size = encoder_config.subsampling_conv_kernel_size
+ stride = encoder_config.subsampling_conv_stride
+
+ num_layers = 2
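+        # the subsampler stacks two convolutions (conv_0 and conv_1) with the same kernel size and stride,
+        # so the standard convolution output-length formula is applied twice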
+ for _ in range(num_layers):
+ input_lengths = (input_lengths - kernel_size) // stride + 1
+
+ return input_lengths
+
+ def _get_output_attention_mask(self, attention_mask: torch.Tensor, target_length: Optional[int] = None):
+ """
+ Convert the input attention mask to its subsampled form. `target_length` sets the desired output length, useful
+ when the attention mask length differs from `sum(-1).max()` (i.e., when the longest sequence in the batch is padded)
+ """
+ output_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
+ # Use target_length if provided, otherwise use max length in batch
+ max_length = target_length if target_length is not None else output_lengths.max()
+ attention_mask = torch.arange(max_length, device=attention_mask.device) < output_lengths[:, None]
+ return attention_mask
+
+
+@auto_docstring(
+ custom_intro="""
+    The LasrEncoder model, based on the [Conformer architecture](https://arxiv.org/abs/2005.08100).
+ """
+)
+class LasrEncoder(LasrPreTrainedModel):
+ config: LasrEncoderConfig
+ base_model_prefix = "encoder"
+
+ def __init__(self, config: LasrEncoderConfig):
+ super().__init__(config)
+ self.gradient_checkpointing = False
+
+ self.dropout = config.dropout
+ self.dropout_positions = config.dropout_positions
+ self.layerdrop = config.layerdrop
+
+ self.subsampler = LasrEncoderSubsampling(config)
+ self.rotary_emb = LasrEncoderRotaryEmbedding(config)
+ self.layers = nn.ModuleList(
+ [LasrEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.out_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
+
+ self.post_init()
+
+ @auto_docstring
+ @check_model_inputs()
+ @can_return_tuple
+ def forward(
+ self,
+ input_features: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutput:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, LasrEncoder
+ >>> from datasets import load_dataset, Audio
+
+ >>> model_id = TODO
+ >>> processor = AutoProcessor.from_pretrained(model_id)
+        >>> encoder = LasrEncoder.from_pretrained(model_id)
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
+
+ >>> inputs = processor(ds[0]["audio"]["array"])
+ >>> encoder_outputs = encoder(**inputs)
+
+ >>> print(encoder_outputs.last_hidden_state.shape)
+ ```
+ """
+
+ hidden_states = self.subsampler(input_features)
+ cos, sin = self.rotary_emb(
+ hidden_states, torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
+ )
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ cos = nn.functional.dropout(cos, p=self.dropout_positions, training=self.training)
+ sin = nn.functional.dropout(sin, p=self.dropout_positions, training=self.training)
+
+ if attention_mask is not None:
+ attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
+
+ attention_mask = create_bidirectional_mask(
+ config=self.config,
+ input_embeds=hidden_states,
+ attention_mask=attention_mask,
+ )
+
+ for encoder_layer in self.layers:
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ to_drop = False
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop: # skip the layer
+ to_drop = True
+
+ if not to_drop:
+ hidden_states = encoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_embeddings=(cos, sin),
+ **kwargs,
+ )
+
+ hidden_states = self.out_norm(hidden_states)
+
+ return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+@dataclass
+class LasrGenerateOutput(ModelOutput):
+ """
+ Outputs of Lasr models.
+
+ Args:
+ sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
+ if all batches finished early due to the `eos_token_id`.
+ logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True`):
+ Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
+ at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
+ each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
+ attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
+ Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
+ hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
+ Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
+ """
+
+ sequences: torch.LongTensor
+ logits: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
+
+
+@auto_docstring(
+ custom_intro="""
+ Lasr Encoder with a Connectionist Temporal Classification (CTC) head.
+ """
+)
+class LasrForCTC(LasrPreTrainedModel):
+ config: LasrCTCConfig
+
+ def __init__(self, config: LasrCTCConfig):
+ super().__init__(config)
+ self.encoder = LasrEncoder(config.encoder_config)
+ # Conv rather than linear to be consistent with the NeMo decoding layer
+ self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1)
+
+ self.post_init()
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_features: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutput:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, LasrForCTC
+ >>> from datasets import load_dataset, Audio
+
+ >>> model_id = "nvidia/lasr-ctc-1.1b"
+ >>> processor = AutoProcessor.from_pretrained(model_id)
+ >>> model = LasrForCTC.from_pretrained(model_id)
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
+
+ >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
+ >>> outputs = model(**inputs)
+
+ >>> print(outputs.loss)
+ ```"""
+
+ encoder_outputs = self.encoder(
+ input_features=input_features,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+
+ hidden_states = encoder_outputs.last_hidden_state
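+ # nn.Conv1d expects (batch, channels, time), hence the transposes around the CTC head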
+ logits = self.ctc_head(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ loss = None
+ if labels is not None:
+ # retrieve loss input_lengths from attention_mask
+ attention_mask = (
+ attention_mask if attention_mask is not None else torch.ones_like(input_features, dtype=torch.long)
+ )
+ input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
+
+ # assuming that padded tokens are filled with -100
+ # when not being attended to
+ labels_mask = labels != self.config.pad_token_id
+ target_lengths = labels_mask.sum(-1)
+ flattened_targets = labels.masked_select(labels_mask)
+
+ # ctc_loss doesn't support fp16
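+ # nn.functional.ctc_loss expects log_probs of shape (time, batch, vocab), hence the transpose(0, 1)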
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+ with torch.backends.cudnn.flags(enabled=False):
+ loss = nn.functional.ctc_loss(
+ log_probs,
+ flattened_targets,
+ input_lengths,
+ target_lengths,
+ blank=self.config.pad_token_id,
+ reduction=self.config.ctc_loss_reduction,
+ zero_infinity=self.config.ctc_zero_infinity,
+ )
+
+ return CausalLMOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ input_features: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ return_dict_in_generate: bool = False,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[LasrGenerateOutput, torch.LongTensor]:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, LasrForCTC
+ >>> from datasets import load_dataset, Audio
+
+ >>> model_id = TODO
+ >>> processor = AutoProcessor.from_pretrained(model_id)
+ >>> model = LasrForCTC.from_pretrained(model_id)
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
+
+ >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
+ >>> predicted_ids = model.generate(**inputs)
+ >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+
+ >>> print(transcription)
+ ```
+ """
+ kwargs["return_dict"] = True
+ outputs: CausalLMOutput = self.forward(
+ input_features=input_features,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+
+ # greedy decoding
+ sequences = outputs.logits.argmax(dim=-1)
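+ # repeated ids and blank (pad) ids are collapsed later by the tokenizer's CTC-aware decoding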
+
+ # mask out padded tokens
+ if attention_mask is not None:
+ attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequences.shape[1])
+ sequences[~attention_mask] = self.config.pad_token_id
+
+ if return_dict_in_generate:
+ return LasrGenerateOutput(
+ sequences=sequences,
+ logits=outputs.logits,
+ attentions=outputs.attentions,
+ hidden_states=outputs.hidden_states,
+ )
+
+ return sequences
+
+
+__all__ = ["LasrForCTC", "LasrEncoder", "LasrPreTrainedModel"]
diff --git a/src/transformers/models/lasr/modular_lasr.py b/src/transformers/models/lasr/modular_lasr.py
new file mode 100644
index 000000000000..c02b2ae0f1c3
--- /dev/null
+++ b/src/transformers/models/lasr/modular_lasr.py
@@ -0,0 +1,569 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+from collections.abc import Callable
+from typing import Optional, Union
+
+import torch
+from tokenizers import Tokenizer
+from tokenizers.models import Unigram
+from torch import nn
+
+from ...masking_utils import create_bidirectional_mask
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...tokenization_utils_tokenizers import TokenizersBackend
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import check_model_inputs
+from ..llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding, apply_rotary_pos_emb, eager_attention_forward
+from ..parakeet.configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
+from ..parakeet.modeling_parakeet import (
+ ParakeetEncoderBlock,
+ ParakeetEncoderConvolutionModule,
+ ParakeetForCTC,
+ ParakeetPreTrainedModel,
+)
+from ..parakeet.processing_parakeet import ParakeetProcessor
+from ..t5.tokenization_t5 import T5Tokenizer
+
+
+class LasrTokenizer(T5Tokenizer, TokenizersBackend):
+ def __init__(
+ self,
+ eos_token="",
+ unk_token="",
+ pad_token="",
+ extra_ids=100,
+ additional_special_tokens=None,
+ vocab=None,
+ vocab_file=None,
+ **kwargs,
+ ):
+ super().__init__(
+ eos_token=eos_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ extra_ids=extra_ids,
+ additional_special_tokens=additional_special_tokens,
+ vocab=vocab,
+ vocab_file=vocab_file,
+ **kwargs,
+ )
+ self._tokenizer = Tokenizer(
+ Unigram(
+ self._vocab_scores,
+ unk_id=3,
+ byte_fallback=False,
+ )
+ )
+
+ def _decode(
+ self,
+ token_ids: Union[int, list[int]],
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: Optional[bool] = None,
+ group_tokens: bool = True,
+ **kwargs,
+ ) -> str:
+ if isinstance(token_ids, int):
+ token_ids = [token_ids]
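+ # CTC collapse: merge runs of identical consecutive ids into a single id before decoding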
+ if group_tokens:
+ token_ids = [token_group[0] for token_group in itertools.groupby(token_ids)]
+
+ # for CTC we filter out the blank token, which is the pad token
+ token_ids = [token for token in token_ids if token != self.pad_token_id]
+
+ return TokenizersBackend._decode(
+ self,
+ token_ids=token_ids,
+ skip_special_tokens=skip_special_tokens,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+
+
+class LasrProcessor(ParakeetProcessor):
+ tokenizer_class = "ParakeetTokenizerFast"
+
+
+class LasrEncoderConfig(ParakeetEncoderConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LasrEncoder`]. It is used to instantiate a
+ `LasrEncoder` model according to the specified arguments, defining the model architecture.
+
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PreTrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 512):
+ Dimension of the layers and the hidden states.
+ num_hidden_layers (`int`, *optional*, defaults to 17):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 2048):
+ Dimension of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the encoder and pooler.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in the attention layers.
+ convolution_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in convolutions of the conformer's convolution module.
+ conv_kernel_size (`int`, *optional*, defaults to 32):
+ The kernel size of the convolution layers in the Conformer block.
+ subsampling_conv_channels (`int`, *optional*, defaults to 256):
+ The number of channels in the subsampling convolution layers.
+ subsampling_conv_kernel_size (`int`, *optional*, defaults to 5):
+ The kernel size of the subsampling convolution layers.
+ subsampling_conv_stride (`int`, *optional*, defaults to 2):
+ The stride of the subsampling convolution layers.
+ num_mel_bins (`int`, *optional*, defaults to 128):
+ Number of mel features.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for all fully connected layers in the embeddings, encoder, and pooler.
+ dropout_positions (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the positions in the input sequence.
+ layerdrop (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the layers in the encoder.
+ activation_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for activations inside the fully connected layer.
+ attention_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention layers.
+ max_position_embeddings (`int`, *optional*, defaults to 10000):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ feed_forward_residual_weights (`tuple[float, float]`, *optional*, defaults to `[1.5, 0.5]`):
+ The residual weights for the feed forward layers.
+ conv_residual_weights (`tuple[float, float]`, *optional*, defaults to `[2.0, 1.0]`):
+ The residual weights for the convolution layers.
+ batch_norm_momentum (`float`, *optional*, defaults to 0.01):
+ The momentum for the batch normalization layers.
+ rope_parameters (`RopeParameters`, *optional*):
+ Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
+ a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
+ with longer `max_position_embeddings`.
+
+ Example:
+ ```python
+ >>> from transformers import LasrEncoderModel, LasrEncoderConfig
+
+ >>> # Initializing a `LasrEncoder` configuration
+ >>> configuration = LasrEncoderConfig()
+
+ >>> # Initializing a model from the configuration
+ >>> model = LasrEncoderModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+
+ This configuration class is based on the LasrEncoder architecture from Google Health AI. You can find more details
+ and pre-trained models at [TODO/TODO](https://huggingface.co/TODO/TODO).
+ """
+
+ def __init__(
+ self,
+ hidden_size=512,
+ num_hidden_layers=17,
+ num_attention_heads=8,
+ intermediate_size=2048,
+ hidden_act="silu",
+ attention_bias=False,
+ convolution_bias=False,
+ conv_kernel_size=32,
+ subsampling_conv_channels=256,
+ subsampling_conv_kernel_size=5,
+ subsampling_conv_stride=2,
+ num_mel_bins=128,
+ dropout=0.1,
+ dropout_positions=0.0,
+ layerdrop=0.1,
+ activation_dropout=0.1,
+ attention_dropout=0.1,
+ max_position_embeddings=10000,
+ initializer_range=0.02,
+ layer_norm_eps=1e-6,
+ feed_forward_residual_weights=[1.5, 0.5],
+ conv_residual_weights=[2.0, 1.0],
+ batch_norm_momentum=0.01,
+ rope_parameters=None,
+ **kwargs,
+ ):
+ self.rope_parameters = rope_parameters
+ self.layer_norm_eps = layer_norm_eps
+ self.feed_forward_residual_weights = feed_forward_residual_weights
+ self.conv_residual_weights = conv_residual_weights
+ self.batch_norm_momentum = batch_norm_momentum
+
+ super().__init__(
+ hidden_size=hidden_size,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ intermediate_size=intermediate_size,
+ hidden_act=hidden_act,
+ attention_bias=attention_bias,
+ convolution_bias=convolution_bias,
+ conv_kernel_size=conv_kernel_size,
+ subsampling_conv_channels=subsampling_conv_channels,
+ num_mel_bins=num_mel_bins,
+ subsampling_conv_kernel_size=subsampling_conv_kernel_size,
+ subsampling_conv_stride=subsampling_conv_stride,
+ dropout=dropout,
+ dropout_positions=dropout_positions,
+ layerdrop=layerdrop,
+ activation_dropout=activation_dropout,
+ attention_dropout=attention_dropout,
+ max_position_embeddings=max_position_embeddings,
+ initializer_range=initializer_range,
+ **kwargs,
+ )
+
+ del self.subsampling_factor
+ del self.scale_input
+
+
+class LasrCTCConfig(ParakeetCTCConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LasrForCTC`]. It is used to instantiate a
+ Lasr CTC model according to the specified arguments, defining the model architecture.
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PreTrainedConfig`] for more information.
+ Args:
+ vocab_size (`int`, *optional*, defaults to 512):
+ Vocabulary size of the model.
+ ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`):
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+ instance of [`LasrForCTC`].
+ ctc_zero_infinity (`bool`, *optional*, defaults to `True`):
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
+ of [`LasrForCTC`].
+ encoder_config (`Union[dict, LasrEncoderConfig]`, *optional*):
+ The config object or dictionary of the encoder.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id. Also used as blank token id.
+ Example:
+ ```python
+ >>> from transformers import LasrForCTC, LasrCTCConfig
+ >>> # Initializing a Lasr configuration
+ >>> configuration = LasrCTCConfig()
+ >>> # Initializing a model from the configuration
+ >>> model = LasrForCTC(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ This configuration class is based on the Lasr CTC architecture from Google Health AI. You can find more details
+ and pre-trained models at [TODO/TODO](https://huggingface.co/TODO/TODO).
+ """
+
+ def __init__(
+ self,
+ vocab_size=512,
+ ctc_loss_reduction="mean",
+ ctc_zero_infinity=True,
+ encoder_config: Union[dict, LasrEncoderConfig] = None,
+ pad_token_id=0,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_size=vocab_size,
+ ctc_loss_reduction=ctc_loss_reduction,
+ ctc_zero_infinity=ctc_zero_infinity,
+ encoder_config=encoder_config,
+ pad_token_id=pad_token_id,
+ **kwargs,
+ )
+
+
+class LasrEncoderSubsampling(nn.Module):
+ def __init__(self, config: LasrEncoderConfig):
+ super().__init__()
+ self.dense_0 = nn.Linear(config.num_mel_bins, config.hidden_size)
+ self.conv_0 = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=config.subsampling_conv_kernel_size,
+ stride=config.subsampling_conv_stride,
+ )
+ self.conv_1 = nn.Conv1d(
+ config.hidden_size,
+ config.subsampling_conv_channels,
+ kernel_size=config.subsampling_conv_kernel_size,
+ stride=config.subsampling_conv_stride,
+ )
+ self.dense_1 = nn.Linear(config.subsampling_conv_channels, config.hidden_size)
+ self.act_fn = nn.ReLU()
+
+ def forward(self, input_features: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.act_fn(self.dense_0(input_features))
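+ # the Conv1d layers operate on (batch, channels, time), so move in and out of channels-first layout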
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.act_fn(self.conv_0(hidden_states))
+ hidden_states = self.act_fn(self.conv_1(hidden_states))
+ hidden_states = hidden_states.transpose(1, 2)
+ return self.dense_1(hidden_states)
+
+
+class LasrEncoderRotaryEmbedding(LlamaRotaryEmbedding): ...
+
+
+class LasrEncoderAttention(LlamaAttention):
+ def __init__(self, config: LasrEncoderConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class LasrEncoderConvolutionModule(ParakeetEncoderConvolutionModule):
+ def __init__(self, config: LasrEncoderConfig, module_config=None):
+ super().__init__(config, module_config)
+ self.padding = "same"
+ self.norm = nn.BatchNorm1d(config.hidden_size, momentum=config.batch_norm_momentum)
+
+
+class LasrEncoderBlock(ParakeetEncoderBlock):
+ def __init__(self, config: LasrEncoderConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+
+ self.feed_forward_residual_weights = config.feed_forward_residual_weights
+ self.conv_residual_weights = config.conv_residual_weights
+
+ self.norm_feed_forward1 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
+ self.norm_self_att = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
+ self.norm_conv = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
+ self.norm_feed_forward2 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
+ self.norm_out = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
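+ # Conformer-style block: feed-forward, self-attention, convolution and feed-forward sub-blocks,
+ # each pre-normalized with (weighted) residual connections, followed by a final output layer norm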
+ residual = hidden_states
+ hidden_states = self.feed_forward1(self.norm_feed_forward1(hidden_states))
+ hidden_states = (
+ self.feed_forward_residual_weights[0] * residual + self.feed_forward_residual_weights[1] * hidden_states
+ )
+
+ normalized_hidden_states = self.norm_self_att(hidden_states)
+ attn_output, _ = self.self_attn(
+ hidden_states=normalized_hidden_states,
+ attention_mask=attention_mask,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = hidden_states + attn_output
+
+ conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
+ hidden_states = self.conv_residual_weights[0] * hidden_states + self.conv_residual_weights[1] * conv_output
+
+ residual = hidden_states
+ hidden_states = self.feed_forward2(self.norm_feed_forward2(hidden_states))
+ hidden_states = (
+ self.feed_forward_residual_weights[0] * residual + self.feed_forward_residual_weights[1] * hidden_states
+ )
+
+ hidden_states = self.norm_out(hidden_states)
+
+ return hidden_states
+
+
+class LasrPreTrainedModel(ParakeetPreTrainedModel):
+ def _init_weights(self, module):
+ PreTrainedModel._init_weights(self, module)
+
+ def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
+ encoder_config = self.config.encoder_config if isinstance(self.config, LasrCTCConfig) else self.config
+ kernel_size = encoder_config.subsampling_conv_kernel_size
+ stride = encoder_config.subsampling_conv_stride
+
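+ # two unpadded strided Conv1d layers: each applies L_out = (L_in - kernel_size) // stride + 1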
+ num_layers = 2
+ for _ in range(num_layers):
+ input_lengths = (input_lengths - kernel_size) // stride + 1
+
+ return input_lengths
+
+
+@auto_docstring(
+ custom_intro="""
+ The LasrEncoder model, based on the [Conformer architecture](https://arxiv.org/abs/2005.08100).
+ """
+)
+class LasrEncoder(LasrPreTrainedModel):
+ config: LasrEncoderConfig
+ base_model_prefix = "encoder"
+
+ def __init__(self, config: LasrEncoderConfig):
+ super().__init__(config)
+ self.gradient_checkpointing = False
+
+ self.dropout = config.dropout
+ self.dropout_positions = config.dropout_positions
+ self.layerdrop = config.layerdrop
+
+ self.subsampler = LasrEncoderSubsampling(config)
+ self.rotary_emb = LasrEncoderRotaryEmbedding(config)
+ self.layers = nn.ModuleList(
+ [LasrEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.out_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
+
+ self.post_init()
+
+ @auto_docstring
+ @check_model_inputs()
+ @can_return_tuple
+ def forward(
+ self,
+ input_features: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutput:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, LasrEncoder
+ >>> from datasets import load_dataset, Audio
+
+ >>> model_id = TODO
+ >>> processor = AutoProcessor.from_pretrained(model_id)
+ >>> encoder = LasrEncoder.from_pretrained(model_id)
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
+
+ >>> inputs = processor(ds[0]["audio"]["array"])
+ >>> encoder_outputs = encoder(**inputs)
+
+ >>> print(encoder_outputs.last_hidden_state.shape)
+ ```
+ """
+
+ hidden_states = self.subsampler(input_features)
+ cos, sin = self.rotary_emb(
+ hidden_states, torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
+ )
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ cos = nn.functional.dropout(cos, p=self.dropout_positions, training=self.training)
+ sin = nn.functional.dropout(sin, p=self.dropout_positions, training=self.training)
+
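+ # recompute the feature-level attention mask at the subsampled resolution so it matches hidden_states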
+ if attention_mask is not None:
+ attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
+
+ attention_mask = create_bidirectional_mask(
+ config=self.config,
+ input_embeds=hidden_states,
+ attention_mask=attention_mask,
+ )
+
+ for encoder_layer in self.layers:
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ to_drop = False
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop: # skip the layer
+ to_drop = True
+
+ if not to_drop:
+ hidden_states = encoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_embeddings=(cos, sin),
+ **kwargs,
+ )
+
+ hidden_states = self.out_norm(hidden_states)
+
+ return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+class LasrForCTC(ParakeetForCTC):
+ def generate(self, **super_kwargs):
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, LasrForCTC
+ >>> from datasets import load_dataset, Audio
+
+ >>> model_id = TODO
+ >>> processor = AutoProcessor.from_pretrained(model_id)
+ >>> model = LasrForCTC.from_pretrained(model_id)
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
+
+ >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
+ >>> predicted_ids = model.generate(**inputs)
+ >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+
+ >>> print(transcription)
+ ```
+ """
+ return super().generate(**super_kwargs)
+
+
+__all__ = [
+ "LasrForCTC",
+ "LasrEncoder",
+ "LasrPreTrainedModel",
+ "LasrProcessor",
+ "LasrEncoderConfig",
+ "LasrCTCConfig",
+ "LasrTokenizer",
+]
diff --git a/src/transformers/models/lasr/processing_lasr.py b/src/transformers/models/lasr/processing_lasr.py
new file mode 100644
index 000000000000..3396986866e2
--- /dev/null
+++ b/src/transformers/models/lasr/processing_lasr.py
@@ -0,0 +1,96 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/lasr/modular_lasr.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_lasr.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+from ...audio_utils import AudioInput, make_list_of_audio
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class LasrProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "audio_kwargs": {
+ "sampling_rate": 16000,
+ "padding": "longest",
+ "return_attention_mask": True,
+ },
+ "text_kwargs": {
+ "padding": True,
+ "padding_side": "right",
+ "add_special_tokens": False,
+ },
+ "common_kwargs": {"return_tensors": "pt"},
+ }
+
+
+class LasrProcessor(ProcessorMixin):
+ tokenizer_class = "ParakeetTokenizerFast"
+
+ def __init__(self, feature_extractor, tokenizer):
+ super().__init__(feature_extractor, tokenizer)
+
+ def __call__(
+ self,
+ audio: AudioInput,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput], None] = None,
+ sampling_rate: Optional[int] = None,
+ **kwargs: Unpack[LasrProcessorKwargs],
+ ):
+ audio = make_list_of_audio(audio)
+
+ output_kwargs = self._merge_kwargs(
+ LasrProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ if sampling_rate is None:
+ logger.warning_once(
+ f"You've provided audio without specifying the sampling rate. It will be assumed to be {output_kwargs['audio_kwargs']['sampling_rate']}, which can result in silent errors."
+ )
+ elif sampling_rate != output_kwargs["audio_kwargs"]["sampling_rate"]:
+ raise ValueError(
+ f"The sampling rate of the audio ({sampling_rate}) does not match the sampling rate of the processor ({output_kwargs['audio_kwargs']['sampling_rate']}). Please provide resampled the audio to the expected sampling rate."
+ )
+
+ if audio is not None:
+ inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
+ if text is not None:
+ encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
+
+ if text is None:
+ return inputs
+ else:
+ inputs["labels"] = encodings["input_ids"]
+ return inputs
+
+ @property
+ def model_input_names(self):
+ feature_extractor_input_names = self.feature_extractor.model_input_names
+ return feature_extractor_input_names + ["labels"]
+
+
+__all__ = ["LasrProcessor"]
diff --git a/src/transformers/models/lasr/tokenization_lasr.py b/src/transformers/models/lasr/tokenization_lasr.py
new file mode 100644
index 000000000000..b88a3a0f9b57
--- /dev/null
+++ b/src/transformers/models/lasr/tokenization_lasr.py
@@ -0,0 +1,190 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/lasr/modular_lasr.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_lasr.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import re
+from typing import Optional, Union
+
+from tokenizers import Tokenizer, decoders, pre_tokenizers, processors
+from tokenizers.models import Unigram
+
+from ...tokenization_utils_tokenizers import TokenizersBackend
+
+
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}
+
+
+class LasrTokenizer(TokenizersBackend):
+ """
+ Construct a LASR tokenizer (backed by HuggingFace's *tokenizers* library). Based on
+ [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models).
+
+ This tokenizer inherits from [`TokenizersBackend`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`, *optional*):
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+ </Tip>
+
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ extra_ids (`int`, *optional*, defaults to 100):
+ Add a number of extra ids added to the vocabulary for use as sentinels. These tokens are accessible as
+ "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. These tokens can be retrieved by
+ calling the `get_sentinel_tokens` method, and their token ids by calling the `get_sentinel_token_ids` method.
+ additional_special_tokens (`list[str]`, *optional*):
+ Additional special tokens used by the tokenizer.
+ vocab (`dict`, *optional*):
+ Custom vocabulary dict. If not provided, a minimal vocabulary is created using the special tokens.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = None
+
+ def __init__(
+ self,
+ eos_token="",
+ unk_token="",
+ pad_token="",
+ extra_ids=100,
+ additional_special_tokens=None,
+ vocab=None,
+ vocab_file=None,
+ **kwargs,
+ ):
+ self.vocab_file = vocab_file
+ self._extra_ids = extra_ids
+
+ # Handle extra_ids and additional_special_tokens
+ if additional_special_tokens is not None:
+ extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
+ if len(extra_tokens) < 1:
+ additional_special_tokens += [f"<extra_id_{i}>" for i in range(extra_ids)]
+ elif extra_ids > 0 and extra_ids != len(extra_tokens):
+ raise ValueError(
+ f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
+ " provided to LasrTokenizer. In this case the additional_special_tokens must include the extra_ids"
+ " tokens"
+ )
+ else:
+ extra_tokens = [f"" for i in range(extra_ids)]
+ additional_special_tokens = extra_tokens
+
+ # LASR vocab structure: <pad>=0, </s>=1, <unk>=2, then regular vocab, then extra_ids in reverse
+ if vocab is not None:
+ self._vocab_scores = vocab
+ else:
+ self._vocab_scores = [
+ (str(pad_token), 0.0),
+ (str(eos_token), 0.0),
+ (str(unk_token), 0.0),
+ ("▁", -2.0), # Space token
+ ]
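+ # extra_id tokens are appended in reverse so that <extra_id_0> gets the highest id, following the T5 convention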
+ for i in range(extra_ids - 1, -1, -1):
+ self._vocab_scores.append((f"<extra_id_{i}>", 0.0))
+ self._tokenizer = Tokenizer(
+ Unigram(
+ self._vocab_scores,
+ unk_id=3,
+ byte_fallback=False,
+ )
+ )
+
+ self._tokenizer.normalizer = None
+
+ self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
+ [
+ pre_tokenizers.WhitespaceSplit(),
+ pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True),
+ ]
+ )
+
+ self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)
+
+ tokenizer_object = self._tokenizer
+
+ super().__init__(
+ tokenizer_object=tokenizer_object,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ extra_ids=extra_ids,
+ additional_special_tokens=additional_special_tokens,
+ **kwargs,
+ )
+
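+ # append the end-of-sequence token after each encoded sequence (and after each segment of a pair)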
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single=["$A", ""],
+ pair=["$A", "", "$B", ""],
+ special_tokens=[
+ ("", self.eos_token_id),
+ ],
+ )
+
+ def get_sentinel_tokens(self):
+ """Get the list of sentinel tokens (extra_id tokens) from additional_special_tokens."""
+ return list(
+ set(filter(lambda x: bool(re.search(r"<extra_id_\d+>", x)), self.additional_special_tokens))
+ )
+
+ def get_sentinel_token_ids(self):
+ """Get the token IDs for sentinel tokens."""
+ return [self.convert_tokens_to_ids(token) for token in self.get_sentinel_tokens()]
+
+ def _decode(
+ self,
+ token_ids: Union[int, list[int]],
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: Optional[bool] = None,
+ group_tokens: bool = True,
+ **kwargs,
+ ) -> str:
+ if isinstance(token_ids, int):
+ token_ids = [token_ids]
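+ # CTC collapse: merge runs of identical consecutive ids into a single id before decoding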
+ if group_tokens:
+ token_ids = [token_group[0] for token_group in itertools.groupby(token_ids)]
+
+ # for CTC we filter out the blank token, which is the pad token
+ token_ids = [token for token in token_ids if token != self.pad_token_id]
+
+ return super()._decode(
+ token_ids=token_ids,
+ skip_special_tokens=skip_special_tokens,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+
+
+__all__ = ["LasrTokenizer"]
diff --git a/src/transformers/models/parakeet/configuration_parakeet.py b/src/transformers/models/parakeet/configuration_parakeet.py
index 057259b04899..48ece84b7b84 100644
--- a/src/transformers/models/parakeet/configuration_parakeet.py
+++ b/src/transformers/models/parakeet/configuration_parakeet.py
@@ -121,9 +121,6 @@ def __init__(
initializer_range=0.02,
**kwargs,
):
- super().__init__(
- **kwargs,
- )
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
@@ -133,10 +130,7 @@ def __init__(
self.attention_bias = attention_bias
self.convolution_bias = convolution_bias
- if (conv_kernel_size - 1) % 2 != 0:
- raise ValueError(f"conv_kernel_size must be odd, got {conv_kernel_size}")
self.conv_kernel_size = conv_kernel_size
-
self.subsampling_conv_kernel_size = subsampling_conv_kernel_size
self.subsampling_conv_stride = subsampling_conv_stride
@@ -153,6 +147,10 @@ def __init__(
self.scale_input = scale_input
self.initializer_range = initializer_range
+ super().__init__(
+ **kwargs,
+ )
+
class ParakeetCTCConfig(PreTrainedConfig):
r"""
diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py
index 6a23ebb454a1..d5e8d8fdf8bf 100644
--- a/src/transformers/models/parakeet/modeling_parakeet.py
+++ b/src/transformers/models/parakeet/modeling_parakeet.py
@@ -155,7 +155,7 @@ def forward(self, hidden_states, attention_mask=None):
Args:
hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
- attention_mask (`torch.Tensor` of shape `(batch, 1, time)`): Attention mask.
+ attention_mask (`torch.Tensor` of shape `(batch, 1, time, time)`): Attention mask.
Returns:
`torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
@@ -171,7 +171,10 @@ def forward(self, hidden_states, attention_mask=None):
# Apply padding mask before convolution
if attention_mask is not None:
- all_masked_rows = torch.all(~attention_mask, dim=-1)
+ if attention_mask.dtype == torch.bool:
+ all_masked_rows = torch.all(~attention_mask, dim=2)
+ else:
+ all_masked_rows = torch.all(~(attention_mask == 0.0), dim=2)
hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)
# 1D Depthwise Conv
diff --git a/src/transformers/models/parakeet/processing_parakeet.py b/src/transformers/models/parakeet/processing_parakeet.py
index 9d69f1458b60..1ac54ba75552 100644
--- a/src/transformers/models/parakeet/processing_parakeet.py
+++ b/src/transformers/models/parakeet/processing_parakeet.py
@@ -28,6 +28,7 @@ class ParakeetProcessorKwargs(ProcessingKwargs, total=False):
"audio_kwargs": {
"sampling_rate": 16000,
"padding": "longest",
+ "return_attention_mask": True,
},
"text_kwargs": {
"padding": True,
diff --git a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
index 6a23e0668083..8ae45c5104f3 100644
--- a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
@@ -365,7 +365,7 @@ def __init__(
self.rope_parameters = rope_parameters
super().__init__(
tie_word_embeddings=tie_word_embeddings,
- ignore_keys_at_rope_validation={"mrope"},
+ ignore_keys_at_rope_validation={"mrope_section"},
**kwargs,
)
@@ -713,7 +713,9 @@ def __init__(
layer_type_validation(self.layer_types, self.num_hidden_layers)
self.rope_parameters = rope_parameters
- super().__init__(tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope"}, **kwargs)
+ super().__init__(
+ tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs
+ )
class Qwen2_5OmniDiTConfig(PreTrainedConfig):
diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
index 3f0b62102644..2bad5d01d7bb 100644
--- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -399,7 +399,7 @@ def __init__(
self.rope_parameters = rope_parameters
super().__init__(
tie_word_embeddings=tie_word_embeddings,
- ignore_keys_at_rope_validation={"mrope"},
+ ignore_keys_at_rope_validation={"mrope_section"},
**kwargs,
)
@@ -747,7 +747,9 @@ def __init__(
layer_type_validation(self.layer_types, self.num_hidden_layers)
self.rope_parameters = rope_parameters
- super().__init__(tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope"}, **kwargs)
+ super().__init__(
+ tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs
+ )
class Qwen2_5OmniDiTConfig(PreTrainedConfig):
diff --git a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py
index 084b4d8c9ce6..8832400df55d 100644
--- a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py
@@ -230,7 +230,7 @@ def __init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
- ignore_keys_at_rope_validation={"mrope"},
+ ignore_keys_at_rope_validation={"mrope_section"},
**kwargs,
)
diff --git a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py
index e4578375036f..8372690ef471 100644
--- a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py
@@ -218,7 +218,7 @@ def __init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
- ignore_keys_at_rope_validation={"mrope"},
+ ignore_keys_at_rope_validation={"mrope_section"},
**kwargs,
)
diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index c8cc28bc626f..63573208d36b 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -460,6 +460,7 @@ def forward(
use_cache: bool = True,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
_, seq_len, _ = input_states.shape
+ batch_size = input_states.shape[0]
y_branch = self.linear_y(input_states)
y_branch = self.act_fn(y_branch)
@@ -468,6 +469,17 @@ def forward(
x_branch = x_branch.transpose(1, 2)
if use_cache:
+ # Check if cache needs initialization (None or batch size mismatch)
+ if self.conv1d_state is None or self.conv1d_state.shape[0] != batch_size:
+ self.conv1d_state = torch.zeros(
+ (batch_size, self.hidden_size, self.conv1d_width - 1),
+ device=input_states.device,
+ dtype=input_states.dtype,
+ )
+ self.rg_lru.recurrent_states = torch.zeros(
+ (batch_size, self.lru_width), device=input_states.device, dtype=torch.float32
+ )
+
if cache_position.shape[0] != 1: # prefill
self.conv1d_state = nn.functional.pad(x_branch, (self.conv1d_width - x_branch.shape[-1] - 1, 0))
x_branch = self.conv_1d(x_branch)[..., :seq_len]
diff --git a/src/transformers/models/smolvlm/processing_smolvlm.py b/src/transformers/models/smolvlm/processing_smolvlm.py
index 98828b2d53a8..f41d44f9a6ab 100644
--- a/src/transformers/models/smolvlm/processing_smolvlm.py
+++ b/src/transformers/models/smolvlm/processing_smolvlm.py
@@ -27,13 +27,6 @@
from ...video_utils import VideoInput
-if is_vision_available():
- from .video_processing_smolvlm import (
- DEFAULT_MEDIA_OUTTRO,
- DEFAULT_VIDEO_INTRO,
- FRAME_TIMESTAMP_MESSAGE,
- )
-
if is_vision_available():
from .video_processing_smolvlm import (
DEFAULT_MEDIA_OUTTRO,
diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py
index 79071edad095..d8ed32213309 100755
--- a/src/transformers/pipelines/__init__.py
+++ b/src/transformers/pipelines/__init__.py
@@ -18,7 +18,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union
-from huggingface_hub import model_info
+from huggingface_hub import is_offline_mode, model_info
from ..configuration_utils import PreTrainedConfig
from ..dynamic_module_utils import get_class_from_dynamic_module
@@ -38,7 +38,6 @@
extract_commit_hash,
find_adapter_config_file,
is_kenlm_available,
- is_offline_mode,
is_peft_available,
is_pyctcdecode_available,
is_torch_available,
diff --git a/src/transformers/pipelines/document_question_answering.py b/src/transformers/pipelines/document_question_answering.py
index 6feb678b1f98..42ada8602a85 100644
--- a/src/transformers/pipelines/document_question_answering.py
+++ b/src/transformers/pipelines/document_question_answering.py
@@ -146,7 +146,9 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- if self.tokenizer is not None and not self.tokenizer.__class__.__name__.endswith("Fast"):
+ if self.tokenizer is not None and not (
+ self.tokenizer.__class__.__name__.endswith("Fast") or self.tokenizer.backend == "tokenizers"
+ ):
raise ValueError(
"`DocumentQuestionAnsweringPipeline` requires a fast tokenizer, but a slow tokenizer "
f"(`{self.tokenizer.__class__.__name__}`) is provided."
diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py
index ca15054f5e66..da846e12da7f 100644
--- a/src/transformers/processing_utils.py
+++ b/src/transformers/processing_utils.py
@@ -28,7 +28,7 @@
import numpy as np
import typing_extensions
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, is_offline_mode
from huggingface_hub.dataclasses import validate_typed_dict
from huggingface_hub.errors import EntryNotFoundError
@@ -54,7 +54,6 @@
cached_file,
copy_func,
direct_transformers_import,
- is_offline_mode,
is_torch_available,
list_repo_templates,
logging,
@@ -696,14 +695,10 @@ def to_dict(self) -> dict[str, Any]:
# extra attributes to be kept
attrs_to_save += ["auto_map"]
- if "tokenizer" in output:
- del output["tokenizer"]
- if "qformer_tokenizer" in output:
- del output["qformer_tokenizer"]
- if "protein_tokenizer" in output:
- del output["protein_tokenizer"]
- if "char_tokenizer" in output:
- del output["char_tokenizer"]
+ for attribute in self.__class__.get_attributes():
+ if "tokenizer" in attribute and attribute in output:
+ del output[attribute]
+
if "chat_template" in output:
del output["chat_template"]
diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py
index 94fc34c1e70d..4d7b20cca7db 100644
--- a/src/transformers/quantizers/base.py
+++ b/src/transformers/quantizers/base.py
@@ -75,26 +75,14 @@ class HfQuantizer(ABC):
Attributes
quantization_config (`transformers.utils.quantization_config.QuantizationConfigMixin`):
The quantization config that defines the quantization parameters of your model that you want to quantize.
- modules_to_not_convert (`list[str]`, *optional*):
- The list of module names to not convert when quantizing the model.
- required_packages (`list[str]`, *optional*):
- The list of required pip packages to install prior to using the quantizer
requires_calibration (`bool`):
Whether the quantization method requires to calibrate the model before using it.
- requires_parameters_quantization (`bool`):
- Whether the quantization method requires to create a new Parameter. For example, for bitsandbytes, it is
- required to create a new xxxParameter in order to properly quantize the model.
"""
requires_calibration = False
- required_packages = None
- requires_parameters_quantization = False
def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
self.quantization_config = quantization_config
-
- # -- Handle extra kwargs below --
- self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
self.pre_quantized = kwargs.pop("pre_quantized", True)
if not self.pre_quantized and self.requires_calibration:
@@ -157,53 +145,16 @@ def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "
return mapping[custom_dtype]
return param.element_size()
- def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
- """
- Override this method if you want to adjust the `missing_keys`.
-
- Args:
- missing_keys (`list[str]`, *optional*):
- The list of missing keys in the checkpoint compared to the state dict of the model
- """
- return missing_keys
-
- def update_expected_keys(self, model, expected_keys: list[str], loaded_keys: list[str]) -> list[str]:
- """
- Override this method if you want to adjust the `update_expected_keys`.
-
- Args:
- expected_keys (`list[str]`, *optional*):
- The list of the expected keys in the initialized model.
- loaded_keys (`list[str]`, *optional*):
- The list of the loaded keys in the checkpoint.
- """
- return expected_keys
-
- def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
- return unexpected_keys
-
def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
"""adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization"""
return max_memory
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
"""
- Check whether a given param needs quantization as defined by `create_quantized_param`.
+ Check whether a given param needs to be quantized.
"""
return False
- def create_quantized_param(self, *args, **kwargs):
- """
- Take needed components from state_dict (those from which `param_needs_quantization` is True) and create
- quantized param.
- It usually also load the new param directly in the `model`.
- Note: only applicable if requires_parameters_quantization == True.
- """
- if not self.requires_parameters_quantization:
- raise AttributeError(
- f"`.create_quantized_param()` method is not supported by quantizer class {self.__class__.__name__}."
- )
-
def validate_environment(self, *args, **kwargs):
"""
This method is used to potentially check for potential conflicts with arguments that are
@@ -263,6 +214,11 @@ def postprocess_model(self, model: "PreTrainedModel", **kwargs):
kwargs (`dict`, *optional*):
The keyword arguments that are passed along `_process_model_after_weight_loading`.
"""
+ model.config.quantization_config = self.quantization_config
+
+ if self.pre_quantized and getattr(self.quantization_config, "dequantize", False):
+ self.remove_quantization_config(model)
+
return self._process_model_after_weight_loading(model, **kwargs)
def remove_quantization_config(self, model):
@@ -285,13 +241,7 @@ def dequantize(self, model):
Note not all quantization schemes support this.
"""
model = self._dequantize(model)
-
- # Delete quantizer and quantization config
- del model.hf_quantizer
- del model.config.quantization_config
- del model.config._pre_quantization_dtype
- del model.quantization_method
- model.is_quantized = False
+ self.remove_quantization_config(model)
return model
@@ -353,10 +303,6 @@ def get_state_dict_and_metadata(self, model, safe_serialization=False):
"""Get state dict and metadata. Useful when we need to modify a bit the state dict due to quantization"""
return None, {}
- def update_state_dict_with_metadata(self, state_dict, metadata):
- """Update state dict with metadata. Default behaviour returns state_dict"""
- return state_dict
-
@abstractmethod
def is_serializable(self, safe_serialization=None): ...
diff --git a/src/transformers/quantizers/quantizer_aqlm.py b/src/transformers/quantizers/quantizer_aqlm.py
index ec3fd033b7b3..11f80af80717 100644
--- a/src/transformers/quantizers/quantizer_aqlm.py
+++ b/src/transformers/quantizers/quantizer_aqlm.py
@@ -39,12 +39,9 @@ class AqlmHfQuantizer(HfQuantizer):
"""
requires_calibration = True
- required_packages = ["aqlm"]
- optimum_quantizer = None
def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
- self.quantization_config = quantization_config
def validate_environment(self, *args, **kwargs):
if not is_accelerate_available():
@@ -77,7 +74,6 @@ def _process_model_before_weight_loading(
quantization_config=self.quantization_config,
linear_weights_not_to_quantize=self.quantization_config.linear_weights_not_to_quantize,
)
- model.config.quantization_config = self.quantization_config
@property
def is_trainable(self) -> bool:
@@ -90,5 +86,5 @@ def is_trainable(self) -> bool:
)
return False
- def is_serializable(self, safe_serialization=None):
+ def is_serializable(self, **kwargs):
return True
diff --git a/src/transformers/quantizers/quantizer_auto_round.py b/src/transformers/quantizers/quantizer_auto_round.py
index faf2fea9d133..9b575ce35aee 100644
--- a/src/transformers/quantizers/quantizer_auto_round.py
+++ b/src/transformers/quantizers/quantizer_auto_round.py
@@ -36,7 +36,6 @@ class AutoRoundQuantizer(HfQuantizer):
# AutoRound requires data calibration - we support only inference
requires_calibration = True
- required_packages = ["auto_round"]
def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
diff --git a/src/transformers/quantizers/quantizer_awq.py b/src/transformers/quantizers/quantizer_awq.py
index 9574f3e1fb34..a6affed5ea4d 100644
--- a/src/transformers/quantizers/quantizer_awq.py
+++ b/src/transformers/quantizers/quantizer_awq.py
@@ -40,8 +40,6 @@ class AwqQuantizer(HfQuantizer):
# AWQ requires data calibration - we support only inference
requires_calibration = True
- required_packages = ["awq", "accelerate"]
-
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
diff --git a/src/transformers/quantizers/quantizer_bitnet.py b/src/transformers/quantizers/quantizer_bitnet.py
index b82357d90e5b..5761d8fc3c7b 100644
--- a/src/transformers/quantizers/quantizer_bitnet.py
+++ b/src/transformers/quantizers/quantizer_bitnet.py
@@ -37,14 +37,10 @@ class BitNetHfQuantizer(HfQuantizer):
Check out the paper introducing this method: https://huggingface.co/papers/2402.17764
"""
- requires_parameters_quantization = False
requires_calibration = True
- required_packages = ["accelerate"]
-
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
- self.quantization_config = quantization_config
def validate_environment(self, *args, **kwargs):
if not is_accelerate_available():
@@ -62,8 +58,8 @@ def validate_environment(self, *args, **kwargs):
"You have loaded a BitNet model on CPU and have a CUDA device available, make sure to set "
"your model on a GPU device in order to run your model."
)
- elif device_map is not None:
- if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
+ elif isinstance(device_map, dict):
+ if len(device_map) > 1 and "cpu" in device_map.values() or "disk" in device_map.values():
raise ValueError(
"You are attempting to load a BitNet model with a device_map that contains a CPU or disk device."
"This is not supported. Please remove the CPU or disk device from the device_map."
diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py
index 3aa847633d95..dd4a06b8719f 100644
--- a/src/transformers/quantizers/quantizer_bnb_4bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from collections import defaultdict
from typing import TYPE_CHECKING
from .base import HfQuantizer
@@ -38,34 +37,20 @@
import torch
from ..core_model_loading import WeightConverter
- from ..pytorch_utils import Conv1D
logger = logging.get_logger(__name__)
class Bnb4BitHfQuantizer(HfQuantizer):
"""
- 4-bit quantization from bitsandbytes quantization method:
- before loading: converts transformer layers into Linear4bit during loading: load 16bit weight and pass to the
- layer object after: quantizes individual weights in Linear4bit into 4bit at the first .cuda() call
- saving:
- from state dict, as usual; saves weights and `quant_state` components
- loading:
- need to locate `quant_state` components and pass to Param4bit constructor
+    4-bit quantization using the bitsandbytes quantization method
"""
- use_keep_in_fp32_modules = True
- requires_parameters_quantization = True
requires_calibration = False
- required_packages = ["bitsandbytes", "accelerate"]
-
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
- if self.quantization_config.llm_int8_skip_modules is not None:
- self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
-
# This describes the additional items that are saved on the state dict (on the params themselves)
self.bnb_keys = [
f"quant_state.bitsandbytes__{self.quantization_config.bnb_4bit_quant_type}",
@@ -90,17 +75,9 @@ def validate_environment(self, *args, **kwargs):
validate_bnb_backend_availability(raise_exception=True)
device_map = kwargs.get("device_map")
- if (
- device_map is not None
- and isinstance(device_map, dict)
- and not self.quantization_config.llm_int8_enable_fp32_cpu_offload
- ):
- device_map_without_lm_head = {
- key: device_map[key] for key in device_map if key not in self.modules_to_not_convert
- }
- if set(device_map.values()) == {"cpu"}:
- pass
- elif "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
+ if not self.quantization_config.llm_int8_enable_fp32_cpu_offload and isinstance(device_map, dict):
+ values = set(device_map.values())
+ if values != {"cpu"} and ("cpu" in values or "disk" in values):
raise ValueError(
"Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
"quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
@@ -117,13 +94,11 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization")
return CustomDtype.INT4
- def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
- return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.bnb_keys)]
-
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
import bitsandbytes as bnb
- # They are on the params themselves, so we cannot easily extract the module from the name
+ # TODO: maybe remove
+        # They are on the params themselves, so we cannot easily extract the module from the name
if any(param_name.endswith(x) for x in self.bnb_keys):
return True
module, name = get_module_from_name(model, param_name)
@@ -142,71 +117,13 @@ def get_param_name(self, param_name: str) -> str:
)
return param_name
- def create_quantized_param(
- self,
- model: "PreTrainedModel",
- param_value: "torch.Tensor",
- param_name: str,
- target_device: "torch.device",
- **kwargs,
- ):
- import bitsandbytes as bnb
-
- full_name = param_name
-
- # update param name to get the weights instead of the quantized stats
- param_name = self.get_param_name(param_name)
- module, tensor_name = get_module_from_name(model, param_name)
-
- # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
- if isinstance(target_device, int) and is_torch_npu_available():
- target_device = f"npu:{target_device}"
-
- # construct `new_value` for the module._parameters[tensor_name]
- if self.pre_quantized:
- module_name = param_name.rsplit(".", 1)[0]
- # Save the states for later quantization when they are all gathered
- if not hasattr(self, "param_quant_stats"):
- self.param_quant_stats = defaultdict(dict)
- self.param_quant_stats[module_name].update({full_name: param_value})
-
- # We are ready for quantization in this case (note, the +1 is for the weight itself)
- if len(self.param_quant_stats[module_name]) == len(self.bnb_keys) + 1:
- weight = self.param_quant_stats[module_name].pop(f"{module_name}.weight")
- new_value = bnb.nn.Params4bit.from_prequantized(
- data=weight,
- quantized_stats=self.param_quant_stats[module_name],
- requires_grad=False,
- device=target_device,
- module=module,
- )
- # Set it
- module._parameters[tensor_name] = new_value
- # Delete the states
- del self.param_quant_stats[module_name]
- else:
- new_value = param_value.to("cpu")
- old_value = getattr(module, tensor_name)
-
- # Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
- # Since weights are saved in the correct "orientation", we skip transposing when loading.
- if issubclass(module.source_cls, Conv1D):
- new_value = new_value.T
-
- kwargs = old_value.__dict__
- kwargs.pop("_is_hf_initialized", None)
- new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
-
- module._parameters[tensor_name] = new_value
-
- # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.adjust_max_memory
def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
# need more space for buffers that are created during quantization
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
return max_memory
- # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.update_dtype
def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
+        # TODO: remove? Is this still true? We will move to dtype="auto", so it will likely be either fp16 or bf16
if dtype is None:
# We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
logger.info(
@@ -238,7 +155,6 @@ def update_device_map(self, device_map):
)
return device_map
- # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_before_weight_loading
def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",
@@ -248,23 +164,15 @@ def _process_model_before_weight_loading(
):
from ..integrations import replace_with_bnb_linear
- llm_int8_enable_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
-
self.modules_to_not_convert = self.get_modules_to_not_convert(
model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
)
- # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
- if isinstance(device_map, dict) and len(device_map.keys()) > 1:
- keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+ if self.quantization_config.llm_int8_enable_fp32_cpu_offload:
+ if isinstance(device_map, dict):
+ keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+ self.modules_to_not_convert.extend(keys_on_cpu)
- if len(keys_on_cpu) > 0 and not llm_int8_enable_fp32_cpu_offload:
- raise ValueError(
- "If you want to offload some keys to `cpu` or `disk`, you need to set "
- "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be "
- " converted to 8-bit but kept in 32-bit."
- )
- self.modules_to_not_convert.extend(keys_on_cpu)
model = replace_with_bnb_linear(
model,
modules_to_not_convert=self.modules_to_not_convert,
@@ -272,15 +180,12 @@ def _process_model_before_weight_loading(
pre_quantized=self.pre_quantized,
)
- model.config.quantization_config = self.quantization_config
-
- # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_after_weight_loading with 8bit->4bit
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
model.is_loaded_in_4bit = True
model.is_4bit_serializable = self.is_serializable()
return model
- def is_serializable(self, safe_serialization=None):
+ def is_serializable(self, **kwargs):
return True
@property
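
To make the consolidated cpu/disk check above easier to follow, here is an illustrative, standalone sketch of the condition the bitsandbytes quantizers now apply to a resolved `device_map` dict. This is not part of the patch; the helper name, the `fp32_cpu_offload` parameter, and the example maps are hypothetical.

```python
def offload_not_allowed(device_map: dict, fp32_cpu_offload: bool = False) -> bool:
    """True when the map mixes cpu/disk entries with accelerator devices."""
    if fp32_cpu_offload or not isinstance(device_map, dict):
        return False
    values = set(device_map.values())
    # A map that is *entirely* on CPU is still accepted; only mixed maps are rejected.
    return values != {"cpu"} and ("cpu" in values or "disk" in values)

assert not offload_not_allowed({"": "cpu"})                        # pure CPU: accepted
assert offload_not_allowed({"model.layers": 0, "lm_head": "cpu"})  # mixed GPU/CPU: rejected
```
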
diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py
index 961c8e207b3f..582144e92220 100644
--- a/src/transformers/quantizers/quantizer_bnb_8bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -25,6 +25,8 @@
is_accelerate_available,
is_bitsandbytes_available,
is_torch_available,
+ is_torch_hpu_available,
+ is_torch_npu_available,
is_torch_xpu_available,
logging,
)
@@ -35,34 +37,20 @@
import torch
from ..core_model_loading import WeightConverter
- from ..pytorch_utils import Conv1D
logger = logging.get_logger(__name__)
class Bnb8BitHfQuantizer(HfQuantizer):
"""
- 8-bit quantization from bitsandbytes quantization method:
- before loading: converts transformer layers into Linear8bitLt during loading: load 16bit weight and pass to the
- layer object after: quantizes individual weights in Linear8bitLt into 8bit at fitst .cuda() call
- saving:
- from state dict, as usual; saves weights and 'SCB' component
- loading:
- need to locate SCB component and pass to the Linear8bitLt object
+    8-bit quantization using the bitsandbytes quantization method
"""
- use_keep_in_fp32_modules = True
- requires_parameters_quantization = True
requires_calibration = False
- required_packages = ["bitsandbytes", "accelerate"]
-
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
- if self.quantization_config.llm_int8_skip_modules is not None:
- self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
-
def validate_environment(self, *args, **kwargs):
if not is_accelerate_available():
raise ImportError(
@@ -78,17 +66,9 @@ def validate_environment(self, *args, **kwargs):
validate_bnb_backend_availability(raise_exception=True)
device_map = kwargs.get("device_map")
- if (
- device_map is not None
- and isinstance(device_map, dict)
- and not self.quantization_config.llm_int8_enable_fp32_cpu_offload
- ):
- device_map_without_lm_head = {
- key: device_map[key] for key in device_map if key not in self.modules_to_not_convert
- }
- if set(device_map.values()) == {"cpu"}:
- pass
- elif "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
+ if not self.quantization_config.llm_int8_enable_fp32_cpu_offload and isinstance(device_map, dict):
+ values = set(device_map.values())
+ if values != {"cpu"} and ("cpu" in values or "disk" in values):
raise ValueError(
"Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
"quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
@@ -120,6 +100,10 @@ def update_device_map(self, device_map):
if device_map is None:
if torch.cuda.is_available():
device_map = {"": torch.cuda.current_device()}
+ elif is_torch_npu_available():
+ device_map = {"": f"npu:{torch.npu.current_device()}"}
+ elif is_torch_hpu_available():
+ device_map = {"": f"hpu:{torch.hpu.current_device()}"}
elif is_torch_xpu_available():
device_map = {"": torch.xpu.current_device()}
else:
@@ -132,61 +116,14 @@ def update_device_map(self, device_map):
return device_map
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
- if target_dtype != torch.int8:
- logger.info("target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization")
return torch.int8
- def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
- bnb_keys = ["SCB", "weight_format"]
- return [k for k in unexpected_keys if not any(k.endswith(x) for x in bnb_keys)]
-
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
import bitsandbytes as bnb
module, name = get_module_from_name(model, param_name)
return isinstance(module, bnb.nn.Linear8bitLt) and name != "bias"
- def create_quantized_param(
- self,
- model: "PreTrainedModel",
- param_value: "torch.Tensor",
- param_name: str,
- target_device: "torch.device",
- **kwargs,
- ):
- import bitsandbytes as bnb
-
- module, tensor_name = get_module_from_name(model, param_name)
-
- if self.pre_quantized and not self.is_serializable():
- raise ValueError(
- "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
- "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
- )
- # Those 2 can only happen when self.pre_quantized == True
- if tensor_name == "SCB":
- setattr(module.weight, "SCB", param_value.to(target_device))
- return
- # It's not used, but it's getting serialized for BC reason...
- elif tensor_name == "weight_format":
- return
-
- # Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
- # Since weights are saved in the correct "orientation", we skip transposing when loading.
- if issubclass(module.source_cls, Conv1D) and not self.pre_quantized:
- param_value = param_value.T
-
- old_value = getattr(module, tensor_name)
- kwargs = old_value.__dict__
- kwargs.pop("_is_hf_initialized", None)
- # Need to pop SCB and reset it because of bnb internals that modifies its value when switching devices ...
- SCB = kwargs.pop("SCB", None)
- new_value = bnb.nn.Int8Params(param_value.to("cpu"), requires_grad=False, **kwargs).to(target_device)
- if SCB is not None:
- setattr(new_value, "SCB", SCB)
- # Set it to the module
- module._parameters[tensor_name] = new_value
-
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
model.is_loaded_in_8bit = True
model.is_8bit_serializable = self.is_serializable()
@@ -201,23 +138,14 @@ def _process_model_before_weight_loading(
):
from ..integrations import replace_with_bnb_linear
- llm_int8_enable_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
-
self.modules_to_not_convert = self.get_modules_to_not_convert(
model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
)
- # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
- if isinstance(device_map, dict) and len(device_map.keys()) > 1:
- keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
-
- if len(keys_on_cpu) > 0 and not llm_int8_enable_fp32_cpu_offload:
- raise ValueError(
- "If you want to offload some keys to `cpu` or `disk`, you need to set "
- "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be "
- " converted to 8-bit but kept in 32-bit."
- )
- self.modules_to_not_convert.extend(keys_on_cpu)
+ if self.quantization_config.llm_int8_enable_fp32_cpu_offload:
+ if isinstance(device_map, dict):
+ keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+ self.modules_to_not_convert.extend(keys_on_cpu)
model = replace_with_bnb_linear(
model,
@@ -226,9 +154,7 @@ def _process_model_before_weight_loading(
pre_quantized=self.pre_quantized,
)
- model.config.quantization_config = self.quantization_config
-
- def is_serializable(self, safe_serialization=None):
+ def is_serializable(self, **kwargs):
return True
@property
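
For reference, an illustrative sketch of the accelerator fallback order that `update_device_map` now follows for 8-bit models when no `device_map` is passed (cuda -> npu -> hpu -> xpu). This is not part of the patch; the `hasattr` guards and the final CPU fallback are assumptions made so the snippet runs on a stock PyTorch install.

```python
import torch

def default_bnb_device_map() -> dict:
    if torch.cuda.is_available():
        return {"": torch.cuda.current_device()}
    if hasattr(torch, "npu") and torch.npu.is_available():  # Ascend NPU builds
        return {"": f"npu:{torch.npu.current_device()}"}
    if hasattr(torch, "hpu") and torch.hpu.is_available():  # Intel Gaudi builds
        return {"": f"hpu:{torch.hpu.current_device()}"}
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return {"": torch.xpu.current_device()}
    return {"": "cpu"}

print(default_bnb_device_map())
```
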
diff --git a/src/transformers/quantizers/quantizer_compressed_tensors.py b/src/transformers/quantizers/quantizer_compressed_tensors.py
index 803c55775214..3f70ca96379e 100644
--- a/src/transformers/quantizers/quantizer_compressed_tensors.py
+++ b/src/transformers/quantizers/quantizer_compressed_tensors.py
@@ -31,7 +31,6 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
"""
requires_calibration = True
- required_packages = ["compressed_tensors"]
def __init__(self, quantization_config: CompressedTensorsConfig, **kwargs):
super().__init__(quantization_config, **kwargs)
@@ -58,9 +57,6 @@ def validate_environment(self, *args, **kwargs):
"Using `compressed_tensors` quantized models requires the compressed-tensors library: "
"`pip install compressed-tensors`"
)
- if not is_torch_available():
- # torch already should be installed as part of compressed tensors
- raise ImportError("torch is required for using compressed-tensors quantization")
def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
if dtype is None:
@@ -113,6 +109,6 @@ def is_qat_trainable(self) -> bool:
# models need to be decompressed carry out qat
return not self.run_compressed or not self.quantization_config.is_quantization_compressed
- def is_serializable(self, safe_serialization=None) -> bool:
+ def is_serializable(self, **kwargs) -> bool:
"""Models quantized using compressed tensors can be saved to disk"""
return True
diff --git a/src/transformers/quantizers/quantizer_eetq.py b/src/transformers/quantizers/quantizer_eetq.py
index 52efea9d00ee..235abb89fd5d 100644
--- a/src/transformers/quantizers/quantizer_eetq.py
+++ b/src/transformers/quantizers/quantizer_eetq.py
@@ -32,19 +32,13 @@
class EetqHfQuantizer(HfQuantizer):
"""
- 8-bit quantization from EETQ quantization method:
- before loading: converts transformer layers into W8A16Linear during loading: load 16bit weight and pass to the
- layer object after: quantizes individual weights in Linear8bitLt into 8bit at first .cuda() call
+    8-bit quantization using the EETQ quantization method
"""
- requires_parameters_quantization = True
requires_calibration = False
- required_packages = ["eetq", "accelerate"]
-
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
- self.quantization_config = quantization_config
def validate_environment(self, *args, **kwargs):
if not is_kernels_available():
@@ -62,8 +56,8 @@ def validate_environment(self, *args, **kwargs):
"You have loaded an EETQ model on CPU and have a CUDA device available, make sure to set "
"your model on a GPU device in order to run your model."
)
- elif device_map is not None:
- if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
+ elif isinstance(device_map, dict):
+            if len(device_map) > 1 and ("cpu" in device_map.values() or "disk" in device_map.values()):
raise ValueError(
"You are attempting to load an EETQ model with a device_map that contains a CPU or disk device."
" This is not supported. Please remove the CPU or disk device from the device_map."
@@ -111,9 +105,7 @@ def _process_model_before_weight_loading(
model, modules_to_not_convert=self.modules_to_not_convert, pre_quantized=self.pre_quantized
)
- model.config.quantization_config = self.quantization_config
-
- def is_serializable(self, safe_serialization=None):
+ def is_serializable(self, **kwargs):
return True
@property
diff --git a/src/transformers/quantizers/quantizer_fbgemm_fp8.py b/src/transformers/quantizers/quantizer_fbgemm_fp8.py
index ae186d3cbdbf..25e6e83e6cca 100644
--- a/src/transformers/quantizers/quantizer_fbgemm_fp8.py
+++ b/src/transformers/quantizers/quantizer_fbgemm_fp8.py
@@ -35,37 +35,23 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
FP8 quantization using fbgemm kernels
"""
- requires_parameters_quantization = True
requires_calibration = False
- required_packages = ["fbgemm-gpu", "accelerate"]
-
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
- self.quantization_config = quantization_config
def validate_environment(self, *args, **kwargs):
- if not is_torch_available():
- raise ImportError(
- "Using fbgemm fp8 quantization requires torch >= 2.1.0"
- "Please install the latest version of torch ( pip install --upgrade torch )"
- )
if not is_fbgemm_gpu_available():
raise ImportError(
"Using fbgemm fp8 quantization requires fbgemm-gpu library"
"Please install the latest version of fbgemm-gpu library by following : https://pytorch.org/FBGEMM/fbgemm_gpu-development/InstallationInstructions.html#fbgemm-gpu-install-libraries"
)
-
if not is_accelerate_available():
raise ImportError(
"Loading an FP8 quantized model requires accelerate (`pip install --upgrade accelerate`)"
)
-
- if not torch.cuda.is_available():
- raise RuntimeError("Using FP8 quantized models with fbgemm kernels requires a GPU")
-
compute_capability = torch.cuda.get_device_capability()
- major, minor = compute_capability
+ major, _ = compute_capability
if major < 9:
raise ValueError(
"FP8 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)"
@@ -77,12 +63,8 @@ def validate_environment(self, *args, **kwargs):
"You have loaded an FP8 model on CPU and have a CUDA device available, make sure to set "
"your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. "
)
- elif device_map is not None:
- if (
- not self.pre_quantized
- and isinstance(device_map, dict)
- and ("cpu" in device_map.values() or "disk" in device_map.values())
- ):
+ elif isinstance(device_map, dict):
+ if not self.pre_quantized and ("cpu" in device_map.values() or "disk" in device_map.values()):
raise ValueError(
"You are attempting to load an FP8 model with a device_map that contains a CPU or disk device."
"This is not supported when the model is quantized on the fly. "
@@ -101,7 +83,7 @@ def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
)
elif dtype == torch.float16:
raise ValueError(
- "You cannot use FP8 with dtype=torch.float16.We recommend you passing dtype=torch.bfloat16"
+                "You cannot use FP8 with dtype=torch.float16. We recommend passing dtype=torch.bfloat16"
)
return dtype
@@ -122,76 +104,6 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **
return True
return False
- def create_quantized_param(
- self,
- model: "PreTrainedModel",
- param_value: "torch.Tensor",
- param_name: str,
- target_device: "torch.device",
- **kwargs,
- ):
- from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
-
- module, tensor_name = get_module_from_name(model, param_name)
-
- # Sanity checks
- if isinstance(module, FbgemmFp8Linear):
- if self.pre_quantized or tensor_name == "bias":
- if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
- raise ValueError("Expect quantized weights but got an unquantized weight")
- else:
- if tensor_name == "weight_scale":
- raise ValueError("Expect unquantized weights but got a quantized weight_scale")
- if isinstance(module, FbgemmFp8Llama4TextExperts):
- if not (self.pre_quantized or tensor_name == "bias"):
- if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale":
- raise ValueError("Expect unquantized weights but got a quantized weight_scale")
-
- if isinstance(module, FbgemmFp8Llama4TextExperts):
- if tensor_name == "gate_up_proj":
- # Process each expert separately
- # Transpose the second and third dimension
- transposed_param = param_value.transpose(1, 2)
-
- # Reshape to 2D for quantization
- original_shape = transposed_param.shape
- flattened_param = transposed_param.reshape(-1, original_shape[-1])
-
- # Quantize using per row instead of per column
- new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
-
- # Reshape back to original dimensions
- new_value = new_value_flat.reshape(original_shape)
- new_value = new_value.transpose(1, 2)
- weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1])
- elif tensor_name == "down_proj":
- # Process each expert separately
- # Transpose the weights for proper quantization
- transposed_param = param_value.transpose(1, 2)
-
- # Reshape to 2D for quantization
- original_shape = transposed_param.shape
- flattened_param = transposed_param.reshape(-1, original_shape[-1])
-
- # Quantize using per column
- new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
-
- # Reshape back to original dimensions
- new_value = new_value_flat.reshape(original_shape)
- new_value = new_value.transpose(1, 2)
- weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)
-
- module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter(weight_scale.to(target_device))
- else:
- new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value)
- module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter(
- weight_scale.view(weight_scale.shape[0], 1).to(target_device)
- )
-
- module._parameters[tensor_name] = torch.nn.Parameter(new_value.to(target_device))
-
- del param_name
-
def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",
@@ -200,38 +112,19 @@ def _process_model_before_weight_loading(
):
from ..integrations import replace_with_fbgemm_fp8_linear
- tp_plan = model._tp_plan
self.modules_to_not_convert = self.get_modules_to_not_convert(
model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
)
- config = model.config
model = replace_with_fbgemm_fp8_linear(
model,
modules_to_not_convert=self.modules_to_not_convert,
quantization_config=self.quantization_config,
pre_quantized=self.pre_quantized,
- config=config,
- tp_plan=tp_plan,
+ config=model.config,
+ tp_plan=model._tp_plan,
)
- model.config.quantization_config = self.quantization_config
-
- def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
- from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
-
- not_missing_keys = []
- for name, module in model.named_modules():
- if isinstance(module, (FbgemmFp8Linear, FbgemmFp8Llama4TextExperts)):
- for missing in missing_keys:
- if (
- (name in missing or name in f"{prefix}.{missing}")
- and not missing.endswith(".weight")
- and not missing.endswith(".bias")
- ):
- not_missing_keys.append(missing)
- return [k for k in missing_keys if k not in not_missing_keys]
-
def update_tp_plan(self, config):
if "Llama4" in config.__class__.__name__:
text_plan = {
@@ -279,7 +172,7 @@ def update_tp_plan(self, config):
return config
- def is_serializable(self, safe_serialization=None):
+ def is_serializable(self, **kwargs):
return True
@property
diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py
index 75f49e6c9da1..eb0c38f5e137 100644
--- a/src/transformers/quantizers/quantizer_finegrained_fp8.py
+++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py
@@ -20,26 +20,20 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
Supports both e4m3fn formats based on platform.
"""
- requires_parameters_quantization = True
requires_calibration = False
- required_packages = ["accelerate"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
- self.quantization_config = quantization_config
def validate_environment(self, *args, **kwargs):
- if not is_torch_available():
- raise ImportError(
- "Using fp8 quantization requires torch >= 2.1.0"
- "Please install the latest version of torch ( pip install --upgrade torch )"
- )
-
if not is_accelerate_available():
raise ImportError("Loading an FP8 quantized model requires accelerate (`pip install accelerate`)")
- if (not (torch.cuda.is_available() or is_torch_xpu_available())) and not self.quantization_config.dequantize:
- if self.pre_quantized:
+ if self.quantization_config.dequantize:
+ return
+
+ if not torch.cuda.is_available() and not is_torch_xpu_available():
+ if self.pre_quantized and not self.quantization_config.dequantize:
logger.warning_once(
"Using FP8 quantized models requires a GPU or XPU, we will default to dequantizing the model to bf16 since no GPU or XPU is available"
)
@@ -64,11 +58,12 @@ def validate_environment(self, *args, **kwargs):
"your model on a GPU or XPU device in order to run your model. To remove this warning, "
"pass device_map = 'cuda' or 'xpu'. "
)
- elif device_map is not None:
+ elif isinstance(device_map, dict):
if (
not self.pre_quantized
- and isinstance(device_map, dict)
- and ("cpu" in device_map.values() or "disk" in device_map.values())
+ and len(device_map) > 1
+                and ("cpu" in device_map.values()
+                     or "disk" in device_map.values())
):
raise ValueError(
"You are attempting to load an FP8 model with a device_map that contains a cpu/disk device."
@@ -76,76 +71,6 @@ def validate_environment(self, *args, **kwargs):
"Please use a quantized checkpoint or remove the cpu/disk device from the device_map."
)
- def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
- if dtype is None:
- logger.info("Setting dtype to torch.float32 as no dtype was specified in from_pretrained")
- dtype = torch.float32
- return dtype
-
- # TODO: make this into a `ConversionType` ops -> potentially requires all weights on all ranks
- # depending on the layer type (moe -> no if ep)
- def create_quantized_param(
- self,
- model: "PreTrainedModel",
- param_value: "torch.Tensor",
- param_name: str,
- target_device: "torch.device",
- **kwargs,
- ):
- from ..integrations.finegrained_fp8 import FP8Linear
- from ..modeling_utils import _load_parameter_into_model
-
- # Sanity checks
- module, tensor_name = get_module_from_name(model, param_name)
- if isinstance(module, FP8Linear):
- if self.pre_quantized or tensor_name == "bias":
- if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
- raise ValueError("Expect quantized weights but got an unquantized weight")
- else:
- return
- # if tensor_name == "weight_scale_inv":
- # raise ValueError("Expect unquantized weights but got a quantized weight_scale")
-
- param_value = param_value.to(target_device)
-
- # Get FP8 min/max values
- fp8_min = torch.finfo(torch.float8_e4m3fn).min
- fp8_max = torch.finfo(torch.float8_e4m3fn).max
-
- block_size_m, block_size_n = self.quantization_config.weight_block_size
-
- rows, cols = param_value.shape[-2:]
-
- if rows % block_size_m != 0 or cols % block_size_n != 0:
- raise ValueError(
- f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
- )
- param_value_orig_shape = param_value.shape
-
- param_value = param_value.reshape(
- -1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
- ).permute(0, 1, 3, 2, 4)
-
- # Calculate scaling factor for each block
- max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
- scale = fp8_max / max_abs
- scale_orig_shape = scale.shape
- scale = scale.unsqueeze(-1).unsqueeze(-1)
-
- # Quantize the weights
- quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-
- quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
- # Reshape back to matrix shape
- quantized_param = quantized_param.reshape(param_value_orig_shape)
-
- # Reshape scale to match the number of blocks
- scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
-
- # Load into the model
- _load_parameter_into_model(model, param_name, quantized_param)
- _load_parameter_into_model(model, param_name.rsplit(".", 1)[0] + ".weight_scale_inv", scale)
-
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
from ..integrations.finegrained_fp8 import FP8Expert, FP8Linear
@@ -176,27 +101,6 @@ def _process_model_before_weight_loading(
pre_quantized=self.pre_quantized,
)
- model.config.quantization_config = self.quantization_config
-
- def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
- if self.pre_quantized and self.quantization_config.dequantize:
- self.remove_quantization_config(model)
-
- def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
- from ..integrations import FP8Linear
-
- not_missing_keys = []
- for name, module in model.named_modules():
- if isinstance(module, FP8Linear):
- for missing in missing_keys:
- if (
- (name in missing or name in f"{prefix}.{missing}")
- and not missing.endswith(".weight")
- and not missing.endswith(".bias")
- ):
- not_missing_keys.append(missing)
- return [k for k in missing_keys if k not in not_missing_keys]
-
# NOTE: TP is applied before quantization so this is only to add hooks.
# Quantization is incompatible with DTensors, so we have to anyway have
# gathers! But it should be model independant -> figure out where to put
@@ -226,7 +130,7 @@ def update_tp_plan(self, config):
return config
- def is_serializable(self, safe_serialization=None):
+ def is_serializable(self, **kwargs):
return True
@property
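
For context on what the deleted `create_quantized_param` did here, below is a condensed, standalone sketch of the block-wise FP8 weight quantization it performed. This is not part of the patch; the function name and the (128, 128) default block size are assumptions.

```python
import torch

def quantize_fp8_blockwise(weight: torch.Tensor, block=(128, 128)):
    finfo = torch.finfo(torch.float8_e4m3fn)
    bm, bn = block
    rows, cols = weight.shape[-2:]
    if rows % bm or cols % bn:
        raise ValueError(f"({rows}, {cols}) must be divisible by block sizes {block}")
    orig_shape = weight.shape
    # Split the matrix into (bm, bn) blocks and compute one scale per block.
    blocks = weight.reshape(-1, rows // bm, bm, cols // bn, bn).permute(0, 1, 3, 2, 4)
    scale = finfo.max / torch.amax(blocks.abs(), dim=(-1, -2), keepdim=True)
    quantized = torch.clamp(blocks * scale, finfo.min, finfo.max).to(torch.float8_e4m3fn)
    quantized = quantized.permute(0, 1, 3, 2, 4).reshape(orig_shape)
    # The model stores the reciprocal of the scale as `weight_scale_inv`.
    return quantized, scale.squeeze(-1).squeeze(-1).squeeze().reciprocal()

q, scale_inv = quantize_fp8_blockwise(torch.randn(256, 256))
```
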
diff --git a/src/transformers/quantizers/quantizer_fp_quant.py b/src/transformers/quantizers/quantizer_fp_quant.py
index b5c9f2c8179f..f9d66986a2b4 100644
--- a/src/transformers/quantizers/quantizer_fp_quant.py
+++ b/src/transformers/quantizers/quantizer_fp_quant.py
@@ -36,13 +36,10 @@ class FPQuantHfQuantizer(HfQuantizer):
"""
requires_calibration = False
- requires_parameters_quantization = True
is_qat_trainable = True
- required_packages = ["fp_quant"]
def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
- self.quantization_config = quantization_config
def validate_environment(self, device_map, **kwargs):
if not torch.cuda.is_available() and not is_torch_xpu_available():
@@ -68,15 +65,17 @@ def validate_environment(self, device_map, **kwargs):
"You are attempting to load a FPQuant model without setting device_map."
" Please set device_map comprised of 'cuda' devices."
)
- elif (
- isinstance(device_map, dict)
- and ("cpu" in device_map.values() or "disk" in device_map.values())
- and not self.quantization_config.pseudoquantization
- ):
- raise ValueError(
- "You are attempting to load a FPQuant model with a device_map that contains a CPU or disk device."
- " This is not supported. Please remove the CPU or disk device from the device_map."
- )
+ elif isinstance(device_map, dict):
+ if (
+ not self.quantization_config.pseudoquantization
+ and len(device_map) > 1
+                and ("cpu" in device_map.values()
+                     or "disk" in device_map.values())
+ ):
+ raise ValueError(
+ "You are attempting to load a FPQuant model with a device_map that contains a CPU or disk device."
+ " This is not supported. Please remove the CPU or disk device from the device_map."
+ )
def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
if dtype is None:
@@ -84,50 +83,17 @@ def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
dtype = torch.bfloat16
elif dtype != torch.bfloat16:
raise ValueError(f"Invalid `dtype` {dtype}. fp_quant quantization only supports `dtype=torch.bfloat16`.")
-
return dtype
- def create_quantized_param(
- self,
- model: "PreTrainedModel",
- param_value: "torch.Tensor",
- param_name: str,
- target_device: "torch.device",
- **kwargs,
- ):
- module, _ = get_module_from_name(model, param_name)
-
- if target_device == "cpu" and param_name.endswith("weight"):
- # Works agains hard-coded missing key dispatch to CPU
- return
-
- # The module holds either:
- # * `weight` when `store_master_weights=True`
- # * `qweight` and `scales` when `store_master_weights=False` and `pseudoquantization=False`
- # * `dqweight` when `store_master_weights=False` and `pseudoquantization=True`
-
- if param_name.endswith(".qweight"):
- # Loading a real quantized checkpoint without master weights
- module.qweight = torch.nn.Parameter(
- param_value.to(target_device),
- requires_grad=False,
- )
- module.weight = None
- module.dqweight = None
- return
-
- if param_name.endswith(".dqweight"):
- # Loading a pseudo-quantized checkpoint without master weights
- module.dqweight = torch.nn.Parameter(param_value.to(target_device))
- module.weight = None
- module.qweight = None
- module.scales = None
- return
-
- # Loading master weights or an unquantized checkpoint
- module.weight = torch.nn.Parameter(param_value.to(target_device))
- # Let pre-forward handle the quantization and set None where necessary
- module.pre_forward()
+ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
+ from fp_quant import FPQuantLinear
+
+ module, tensor_name = get_module_from_name(model, param_name)
+ if isinstance(module, FPQuantLinear) and tensor_name in ["weight", "qweight", "dqweight"]:
+ # Only quantize weights of FPQuantLinear modules that are not already quantized
+ return True
+ else:
+ return False
def _process_model_before_weight_loading(
self,
@@ -142,20 +108,6 @@ def _process_model_before_weight_loading(
model,
fp_quant_linear_config=adapt_fp_quant_config(self.quantization_config),
)
- model.config.quantization_config = self.quantization_config
-
- def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
- from fp_quant import FPQuantLinear
-
- fp_quant_names = {name for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}
-
- def should_exclude(key: str) -> bool:
- if key.endswith(".weight") or key.endswith(".bias"):
- return False
- full_key = f"{prefix}.{key}"
- return any(name in key or name in full_key for name in fp_quant_names)
-
- return [key for key in missing_keys if not should_exclude(key)]
@property
def is_trainable(self, model: Optional["PreTrainedModel"] = None):
@@ -166,15 +118,5 @@ def is_trainable(self, model: Optional["PreTrainedModel"] = None):
)
return trainable
- def is_serializable(self, safe_serialization=None):
+ def is_serializable(self, **kwargs):
return True
-
- def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
- from fp_quant import FPQuantLinear
-
- module, tensor_name = get_module_from_name(model, param_name)
- if isinstance(module, FPQuantLinear) and tensor_name in ["weight", "qweight", "dqweight"]:
- # Only quantize weights of FPQuantLinear modules that are not already quantized
- return True
- else:
- return False
diff --git a/src/transformers/quantizers/quantizer_gptq.py b/src/transformers/quantizers/quantizer_gptq.py
index f12ad4ca7e94..8b828bc7ce86 100644
--- a/src/transformers/quantizers/quantizer_gptq.py
+++ b/src/transformers/quantizers/quantizer_gptq.py
@@ -39,8 +39,6 @@ class GptqHfQuantizer(HfQuantizer):
"""
requires_calibration = False
- required_packages = ["optimum", "auto_gptq", "gptqmodel"]
- optimum_quantizer = None
def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
@@ -120,5 +118,5 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs
def is_trainable(self) -> bool:
return True
- def is_serializable(self, safe_serialization=None):
+ def is_serializable(self, **kwargs):
return True
diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py
index f780bd5e52ab..d706e2c9f1ce 100644
--- a/src/transformers/quantizers/quantizer_higgs.py
+++ b/src/transformers/quantizers/quantizer_higgs.py
@@ -37,12 +37,9 @@ class HiggsHfQuantizer(HfQuantizer):
"""
requires_calibration = False
- requires_parameters_quantization = True
- required_packages = ["flute-kernel", "fast_hadamard_transform"]
def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
- self.quantization_config = quantization_config
def validate_environment(self, device_map, **kwargs):
if not torch.cuda.is_available():
@@ -64,11 +61,12 @@ def validate_environment(self, device_map, **kwargs):
"You are attempting to load a HIGGS model without setting device_map."
" Please set device_map comprised of 'cuda' devices."
)
- elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
- raise ValueError(
- "You are attempting to load a HIGGS model with a device_map that contains a CPU or disk device."
- " This is not supported. Please remove the CPU or disk device from the device_map."
- )
+ elif isinstance(device_map, dict):
+ if "cpu" in device_map.values() or "disk" in device_map.values():
+ raise ValueError(
+ "You are attempting to load a HIGGS model with a device_map that contains a CPU or disk device."
+ " This is not supported. Please remove the CPU or disk device from the device_map."
+ )
def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
if dtype is None:
@@ -81,37 +79,39 @@ def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
return dtype
- def create_quantized_param(
- self,
- model: "PreTrainedModel",
- param_value: "torch.Tensor",
- param_name: str,
- target_device: "torch.device",
- **kwargs,
- ):
- from ..integrations import quantize_with_higgs
-
- flute_dict = quantize_with_higgs(
- param_value.to(target_device),
- self.quantization_config.bits,
- self.quantization_config.p,
- self.quantization_config.group_size,
- self.quantization_config.hadamard_size,
- )
- del param_value
-
- module, _ = get_module_from_name(model, param_name)
- module_name = ".".join(param_name.split(".")[:-1])
- for key, value in flute_dict.items():
- if key in module._parameters:
- module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
- elif key in module._buffers:
- module._buffers[key] = torch.nn.Buffer(value)
- elif key == "tune_metadata":
- module.tune_metadata = value
- self.quantization_config.tune_metadata[module_name] = value.to_dict()
- else:
- raise ValueError(f"Unexpected key {key} in module {module}")
+ # TODO: to remove
+ # Kept here in case we see some interest in adding support for it
+ # def create_quantized_param(
+ # self,
+ # model: "PreTrainedModel",
+ # param_value: "torch.Tensor",
+ # param_name: str,
+ # target_device: "torch.device",
+ # **kwargs,
+ # ):
+ # from ..integrations import quantize_with_higgs
+
+ # flute_dict = quantize_with_higgs(
+ # param_value.to(target_device),
+ # self.quantization_config.bits,
+ # self.quantization_config.p,
+ # self.quantization_config.group_size,
+ # self.quantization_config.hadamard_size,
+ # )
+ # del param_value
+
+ # module, _ = get_module_from_name(model, param_name)
+ # module_name = ".".join(param_name.split(".")[:-1])
+ # for key, value in flute_dict.items():
+ # if key in module._parameters:
+ # module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
+ # elif key in module._buffers:
+ # module._buffers[key] = torch.nn.Buffer(value)
+ # elif key == "tune_metadata":
+ # module.tune_metadata = value
+ # self.quantization_config.tune_metadata[module_name] = value.to_dict()
+ # else:
+ # raise ValueError(f"Unexpected key {key} in module {module}")
def _process_model_before_weight_loading(
self,
@@ -130,7 +130,6 @@ def _process_model_before_weight_loading(
quantization_config=self.quantization_config,
modules_to_not_convert=self.modules_to_not_convert,
)
- model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
from flute.tune import TuneMetaData, maybe_tune_and_repack
@@ -157,19 +156,6 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs
)
self.quantization_config.tune_metadata[name] = module.tune_metadata.to_dict()
- def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
- from ..integrations import HiggsLinear
-
- higgs_names = {name for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
-
- def should_update(key: str) -> bool:
- if key.endswith(".weight") or key.endswith(".bias"):
- return False
- full_key = f"{prefix}.{key}"
- return any(name in key or name in full_key for name in higgs_names)
-
- return [key for key in missing_keys if not should_update(key)]
-
@property
def is_trainable(self) -> bool:
return False
diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py
index 94907c3b48fc..2dfbe20b35a4 100755
--- a/src/transformers/quantizers/quantizer_hqq.py
+++ b/src/transformers/quantizers/quantizer_hqq.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from collections import defaultdict
from typing import TYPE_CHECKING
from ..integrations import prepare_for_hqq_linear
@@ -49,10 +48,7 @@ class HqqHfQuantizer(HfQuantizer):
nn.Linear modules are first tagged with quant_config in _process_model_before_weight_loading().
"""
- use_keep_in_fp32_modules = False
- requires_parameters_quantization = True
requires_calibration = False
- required_packages = ["hqq"]
def __init__(self, quantization_config, **kwargs):
if not is_hqq_available():
@@ -83,73 +79,67 @@ def validate_environment(self, *args, **kwargs):
else:
self.using_multi_gpu = len(set(device_map.values())) > 1
- def update_missing_keys(
- self, model: "PreTrainedModel", missing_keys: list[str], prefix: str, **kwargs
- ) -> list[str]:
- if self.pre_quantized:
- return [key for key in missing_keys if ("weight" not in key)]
- else:
- return missing_keys
-
- # Adds missing keys for HQQLinear modules that are loaded but the model with initialized with torch.nn.Linear
- def update_expected_keys(
- self, model: "PreTrainedModel", expected_keys: list[str], loaded_keys: list[str]
- ) -> list[str]:
- if not self.pre_quantized:
- return expected_keys
-
- # Collects all quantizable (linear) layers
- def _find_hqq_quantizable_layers(model, layers):
- for name, module in model.named_children():
- if isinstance(module, (torch.nn.Linear)):
- layers.add(module.name)
- _find_hqq_quantizable_layers(module, layers)
-
- new_keys = set(expected_keys)
-
- # Name modules
- for name, module in model.named_modules():
- module.name = name
-
- # valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
- _valid_modules = set()
- _find_hqq_quantizable_layers(model, _valid_modules)
-
- # Remove skipped modules
- _skipped_modules = set()
- for _module in _valid_modules:
- for _skip_module in model.config.quantization_config["skip_modules"]:
- if _skip_module in _module:
- _skipped_modules.add(_module)
- _valid_modules -= _skipped_modules
-
- # Append new expected layers based on _ref_keys
- _ref_keys = HQQLinear(
- linear_layer=None,
- quant_config=None,
- compute_dtype=torch.float16,
- device="cpu",
- del_orig=False,
- ).state_dict_keys() - {"bias"}
-
- # Clean-up
- _rm_keys = set()
- for key in new_keys:
- if any(_module in key for _module in _valid_modules):
- _rm_keys.add(key)
- new_keys -= _rm_keys
- # At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear
-
- # Re-populate Linear/HQQLinear
- for _module in _valid_modules:
- if _module + ".weight" in loaded_keys:
- new_keys.add(_module + ".weight")
- else:
- new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys})
- if _module + ".bias" in loaded_keys:
- new_keys.add(_module + ".bias")
-
- return list(new_keys)
+ # TODO: to remove
+ # Kept here in case we see some interest in adding support for it
+    # # Adds missing keys for HQQLinear modules that are loaded but the model was initialized with torch.nn.Linear
+ # def update_expected_keys(
+ # self, model: "PreTrainedModel", expected_keys: list[str], loaded_keys: list[str]
+ # ) -> list[str]:
+ # if not self.pre_quantized:
+ # return expected_keys
+
+ # # Collects all quantizable (linear) layers
+ # def _find_hqq_quantizable_layers(model, layers):
+ # for name, module in model.named_children():
+ # if isinstance(module, (torch.nn.Linear)):
+ # layers.add(module.name)
+ # _find_hqq_quantizable_layers(module, layers)
+
+ # new_keys = set(expected_keys)
+
+ # # Name modules
+ # for name, module in model.named_modules():
+ # module.name = name
+
+ # # valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
+ # _valid_modules = set()
+ # _find_hqq_quantizable_layers(model, _valid_modules)
+
+ # # Remove skipped modules
+ # _skipped_modules = set()
+ # for _module in _valid_modules:
+ # for _skip_module in model.config.quantization_config["skip_modules"]:
+ # if _skip_module in _module:
+ # _skipped_modules.add(_module)
+ # _valid_modules -= _skipped_modules
+
+ # # Append new expected layers based on _ref_keys
+ # _ref_keys = HQQLinear(
+ # linear_layer=None,
+ # quant_config=None,
+ # compute_dtype=torch.float16,
+ # device="cpu",
+ # del_orig=False,
+ # ).state_dict_keys() - {"bias"}
+
+ # # Clean-up
+ # _rm_keys = set()
+ # for key in new_keys:
+ # if any(_module in key for _module in _valid_modules):
+ # _rm_keys.add(key)
+ # new_keys -= _rm_keys
+ # # At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear
+
+ # # Re-populate Linear/HQQLinear
+ # for _module in _valid_modules:
+ # if _module + ".weight" in loaded_keys:
+ # new_keys.add(_module + ".weight")
+ # else:
+ # new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys})
+ # if _module + ".bias" in loaded_keys:
+ # new_keys.add(_module + ".bias")
+
+ # return list(new_keys)
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
module, _ = get_module_from_name(model, param_name)
@@ -157,87 +147,88 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **
# `create_quantized_param`, even when `self.is_quantized == True`
return isinstance(module, torch.nn.Linear)
- def create_quantized_param(
- self,
- model: "PreTrainedModel",
- param_value: "torch.Tensor",
- param_name: str,
- target_device: "torch.device",
- **kwargs,
- ):
- module, tensor_name = get_module_from_name(model, param_name)
- module_name = param_name.rsplit(".", 1)[0]
- parent_module, node = get_module_from_name(model, module_name)
-
- quant_config = model.config.quantization_config["quant_config"]
- skip_modules = model.config.quantization_config["skip_modules"]
-
- # In this case we do not quantize this layer (it's explicitly skipped) -> simply load param
- if any(skip_module in module.name for skip_module in skip_modules):
- module.load_state_dict(
- {tensor_name: param_value.to(device=target_device, dtype=self.dtype)}, strict=False, assign=True
- )
- return
-
- # We need this hack as the model is not pre-prepared as an empty skeleton on meta device
- if self.pre_quantized:
- # Save them for later
- if not hasattr(self, "hqq_params"):
- self.hqq_params = defaultdict(dict)
- self.hqq_params[module_name].update({tensor_name: param_value})
- hqq_params = self.hqq_params[module_name]
-
- # If they are all present and saved, make it a HQQLinear layer! (we cannot do it param after param because
- # hqq does not support it...)
- if all(k in hqq_params for k in self.hqq_keys) and ("bias" in hqq_params or module.bias is None):
- hqq_layer = HQQLinear(
- linear_layer=None,
- quant_config=None,
- compute_dtype=self.dtype,
- device=target_device,
- del_orig=False,
- )
- hqq_layer.load_state_dict(hqq_params)
-
- if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
- hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
- if self.using_multi_gpu:
- hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
-
- setattr(parent_module, node, hqq_layer)
- del self.hqq_params[module_name], module
- return
-
- # Load param in the module (without caring about device or dtype, it will be changed later)
- module.load_state_dict({tensor_name: param_value}, strict=False, assign=True)
-
- # If both the weight and bias have already been loaded, time to quantize!
- module_is_ready = module.weight.device.type != "meta" and (
- module.bias is None or module.bias.device.type != "meta"
- )
-
- if module_is_ready:
- module_tag = ".".join(module.name.split(".")[-2:])
- if "weight_quant_params" in quant_config:
- module_quant_config = quant_config
- elif module_tag in quant_config:
- module_quant_config = quant_config[module_tag]
-
- hqq_layer = HQQLinear(
- module,
- quant_config=module_quant_config,
- compute_dtype=self.dtype,
- device=target_device,
- del_orig=True,
- )
-
- if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
- hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
-
- if self.using_multi_gpu:
- hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
-
- setattr(parent_module, node, hqq_layer)
+ # TODO: to remove
+ # def create_quantized_param(
+ # self,
+ # model: "PreTrainedModel",
+ # param_value: "torch.Tensor",
+ # param_name: str,
+ # target_device: "torch.device",
+ # **kwargs,
+ # ):
+ # module, tensor_name = get_module_from_name(model, param_name)
+ # module_name = param_name.rsplit(".", 1)[0]
+ # parent_module, node = get_module_from_name(model, module_name)
+
+ # quant_config = model.config.quantization_config["quant_config"]
+ # skip_modules = model.config.quantization_config["skip_modules"]
+
+ # # In this case we do not quantize this layer (it's explicitly skipped) -> simply load param
+ # if any(skip_module in module.name for skip_module in skip_modules):
+ # module.load_state_dict(
+ # {tensor_name: param_value.to(device=target_device, dtype=self.dtype)}, strict=False, assign=True
+ # )
+ # return
+
+ # # We need this hack as the model is not pre-prepared as an empty skeleton on meta device
+ # if self.pre_quantized:
+ # # Save them for later
+ # if not hasattr(self, "hqq_params"):
+ # self.hqq_params = defaultdict(dict)
+ # self.hqq_params[module_name].update({tensor_name: param_value})
+ # hqq_params = self.hqq_params[module_name]
+
+ # # If they are all present and saved, make it a HQQLinear layer! (we cannot do it param after param because
+ # # hqq does not support it...)
+ # if all(k in hqq_params for k in self.hqq_keys) and ("bias" in hqq_params or module.bias is None):
+ # hqq_layer = HQQLinear(
+ # linear_layer=None,
+ # quant_config=None,
+ # compute_dtype=self.dtype,
+ # device=target_device,
+ # del_orig=False,
+ # )
+ # hqq_layer.load_state_dict(hqq_params)
+
+ # if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
+ # hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
+ # if self.using_multi_gpu:
+ # hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
+
+ # setattr(parent_module, node, hqq_layer)
+ # del self.hqq_params[module_name], module
+ # return
+
+ # # Load param in the module (without caring about device or dtype, it will be changed later)
+ # module.load_state_dict({tensor_name: param_value}, strict=False, assign=True)
+
+ # # If both the weight and bias have already been loaded, time to quantize!
+ # module_is_ready = module.weight.device.type != "meta" and (
+ # module.bias is None or module.bias.device.type != "meta"
+ # )
+
+ # if module_is_ready:
+ # module_tag = ".".join(module.name.split(".")[-2:])
+ # if "weight_quant_params" in quant_config:
+ # module_quant_config = quant_config
+ # elif module_tag in quant_config:
+ # module_quant_config = quant_config[module_tag]
+
+ # hqq_layer = HQQLinear(
+ # module,
+ # quant_config=module_quant_config,
+ # compute_dtype=self.dtype,
+ # device=target_device,
+ # del_orig=True,
+ # )
+
+ # if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
+ # hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
+
+ # if self.using_multi_gpu:
+ # hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
+
+ # setattr(parent_module, node, hqq_layer)
def _patch_layer_for_multigpu(self, hqq_layer):
def forward_with_device(self, x):
diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py
index e5e70af1ab6b..08b8538b8230 100644
--- a/src/transformers/quantizers/quantizer_mxfp4.py
+++ b/src/transformers/quantizers/quantizer_mxfp4.py
@@ -43,14 +43,10 @@ class Mxfp4HfQuantizer(HfQuantizer):
FP4 quantization using fbgemm kernels
"""
- requires_parameters_quantization = True
requires_calibration = False
- required_packages = ["accelerate"]
-
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
- self.quantization_config = quantization_config
self.triton_kernels_hub = None
def _lazy_import_kernels(self):
@@ -74,7 +70,7 @@ def validate_environment(self, *args, **kwargs):
if self.quantization_config.dequantize:
return
- if not (torch.cuda.is_available() or torch.xpu.is_available()):
+ if not torch.cuda.is_available() and not torch.xpu.is_available():
if self.pre_quantized:
logger.warning_once(
"Using MXFP4 quantized models requires a GPU, we will default to dequantizing the model to bf16"
@@ -131,12 +127,8 @@ def validate_environment(self, *args, **kwargs):
"You have loaded an FP4 model on CPU and have a CUDA/XPU device available, make sure to set "
"your model on a GPU/XPU device in order to run your model. To remove this warning, pass device_map = 'cuda' or device_map = 'xpu'. "
)
- elif device_map is not None:
- if (
- not self.pre_quantized
- and isinstance(device_map, dict)
- and ("cpu" in device_map.values() or "disk" in device_map.values())
- ):
+ elif isinstance(device_map, dict):
+ if not self.pre_quantized and ("cpu" in device_map.values() or "disk" in device_map.values()):
raise ValueError(
"You are attempting to load an FP4 model with a device_map that contains a CPU or disk device."
"This is not supported when the model is quantized on the fly. "
@@ -157,146 +149,21 @@ def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
from ..integrations import Mxfp4GptOssExperts
- from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts
- if self.pre_quantized:
- return False
- # if we are dequantizing, the model doesn't have scales, and blocks only params like gate_up_proj and down_proj so we need to handle this case differently
- if self.quantization_config.dequantize and ("blocks" in param_name or "scales" in param_name):
- module, tensor_name = get_module_from_name(model, param_name[: -len("_blocks")])
- else:
- module, tensor_name = get_module_from_name(model, param_name)
- if isinstance(module, Mxfp4GptOssExperts) or (
- isinstance(module, GptOssExperts) and self.quantization_config.dequantize
- ):
+ module, tensor_name = get_module_from_name(model, param_name)
+ if isinstance(module, Mxfp4GptOssExperts):
if tensor_name in ["down_proj_bias", "gate_up_proj_bias"]:
return False
return True
return False
- def create_quantized_param(
- self,
- model: "PreTrainedModel",
- param_value: "torch.Tensor",
- param_name: str,
- target_device: "torch.device",
- **kwargs,
- ):
- from ..integrations import (
- Mxfp4GptOssExperts,
- dequantize,
- load_and_swizzle_mxfp4,
- quantize_to_mxfp4,
- swizzle_mxfp4,
- )
- from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts
-
- if not self.pre_quantized:
- triton_kernels_hub = self._lazy_import_kernels()
- module, _ = get_module_from_name(model, param_name)
- with torch.device(target_device):
- if isinstance(module, Mxfp4GptOssExperts):
- triton_weight_tensor, weight_scale = quantize_to_mxfp4(param_value, triton_kernels_hub)
- PrecisionConfig, FlexCtx, InFlexData = (
- triton_kernels_hub.matmul_ogs.PrecisionConfig,
- triton_kernels_hub.matmul_ogs.FlexCtx,
- triton_kernels_hub.matmul_ogs.InFlexData,
- )
- triton_weight_tensor, weight_scale = swizzle_mxfp4(
- triton_weight_tensor, weight_scale, triton_kernels_hub
- )
-
- proj = "gate_up_proj" if "gate_up_proj" in param_name else "down_proj"
- setattr(module, proj, triton_weight_tensor)
- setattr(
- module,
- f"{proj}_precision_config",
- PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
- )
-
- delattr(module, f"{proj}_blocks")
- delattr(module, f"{proj}_scales")
-
- # The params going here are either gate_up_proj_blocks, or down_proj_blocks, or gate_up_proj_scales, or down_proj_scales
- else:
- # This is when loading a quantized model (blocks and scales exist)
- empty_param = kwargs.get("empty_param")
- casting_dtype = kwargs.get("casting_dtype")
- to_contiguous = kwargs.get("to_contiguous")
- rank = kwargs.get("rank")
- device_mesh = kwargs.get("device_mesh")
- if ("blocks" in param_name or "scales" in param_name) and self.quantization_config.dequantize:
- # blocks and scales have the same length that's why this works for both
- module, _ = get_module_from_name(model, param_name[: -len("_blocks")])
- else:
- module, _ = get_module_from_name(model, param_name)
-
- shard_kwargs = {
- "empty_param": empty_param,
- "casting_dtype": casting_dtype,
- "to_contiguous": to_contiguous,
- "rank": rank,
- "device_mesh": device_mesh,
- "model": model,
- }
-
- if isinstance(module, Mxfp4GptOssExperts) or (
- isinstance(module, GptOssExperts) and self.quantization_config.dequantize
- ):
- if self.quantization_config.dequantize:
- # dq_param_name is the name of the parameter without the blocks or scales suffix, it's used in this case since we don't switch linears
- # so we only have the original param name
- dq_param_name = param_name[: -len("_blocks")]
- dequantize(module, param_name, param_value, target_device, dq_param_name, **shard_kwargs)
- else:
- load_and_swizzle_mxfp4(
- module,
- param_name,
- param_value,
- target_device,
- self._lazy_import_kernels(),
- **shard_kwargs,
- )
-
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
- # we are not really dequantizing, we are just removing everything related to quantization here
- if self.quantization_config.dequantize:
- self.remove_quantization_config(model)
# clean cache due to triton ops
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.xpu.is_available():
torch.xpu.empty_cache()
- def update_expected_keys(self, model: "PreTrainedModel", expected_keys: list[str], checkpoint_keys: list[str]):
- # Replace expected_keys for experts' gate_up_proj and down_proj with their _blocks and _scales variants
- new_expected_keys = []
- for key in expected_keys:
- if key.endswith(".mlp.experts.gate_up_proj"):
- base = key[: -len("gate_up_proj")]
- new_expected_keys.append(base + "gate_up_proj_blocks")
- new_expected_keys.append(base + "gate_up_proj_scales")
- elif key.endswith(".mlp.experts.down_proj"):
- base = key[: -len("down_proj")]
- new_expected_keys.append(base + "down_proj_blocks")
- new_expected_keys.append(base + "down_proj_scales")
- elif not self.pre_quantized:
- # in this case, we are quantizing the model so we need to update the keys as we changed the layers
- if key.endswith(".mlp.experts.down_proj_blocks"):
- base = key[: -len("down_proj_blocks")]
- new_expected_keys.append(base + "down_proj")
- elif key.endswith(".mlp.experts.gate_up_proj_blocks"):
- base = key[: -len("gate_up_proj_blocks")]
- new_expected_keys.append(base + "gate_up_proj")
- elif key.endswith("scales"):
- # we remove it the scales as the checkpoint don't contain them
- continue
- else:
- new_expected_keys.append(key)
- else:
- new_expected_keys.append(key)
- return new_expected_keys
-
def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",
@@ -306,10 +173,6 @@ def _process_model_before_weight_loading(
):
from ..integrations import replace_with_mxfp4_linear
- self.modules_to_not_convert = self.get_modules_to_not_convert(
- model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
- )
-
# if we are using kernels, we can't use the quantized model, since the forward pass is different and needs special handling
if use_kernels:
logger.warning_once(
@@ -318,27 +181,14 @@ def _process_model_before_weight_loading(
)
self.quantization_config.dequantize = True
+ self.modules_to_not_convert = self.get_modules_to_not_convert(
+ model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+ )
+
model = replace_with_mxfp4_linear(
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
)
- model.config.quantization_config = self.quantization_config
-
- def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
- from ..integrations import Mxfp4GptOssExperts
-
- not_missing_keys = []
- for name, module in model.named_modules():
- if isinstance(module, Mxfp4GptOssExperts):
- for missing in missing_keys:
- if (
- (name in missing or name in f"{prefix}.{missing}")
- and not missing.endswith(".weight")
- and not missing.endswith(".bias")
- ):
- not_missing_keys.append(missing)
- return [k for k in missing_keys if k not in not_missing_keys]
-
def update_tp_plan(self, config):
if "GptOssConfig" in config.__class__.__name__:
if getattr(config, "base_model_tp_plan", None) is not None:
@@ -378,7 +228,7 @@ def get_param_name(self, param_name: str) -> str:
return param_name.replace("down_proj", "down_proj_blocks")
return param_name
- def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
+ def get_state_dict_and_metadata(self, model, **kwargs):
from ..integrations import Mxfp4GptOssExperts
state_dict = model.state_dict()
@@ -417,7 +267,7 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
metadata = {}
return state_dict, metadata
- def is_serializable(self, safe_serialization=None):
+ def is_serializable(self, **kwargs):
return True
@property
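
For context on the MXFP4 changes above: `get_param_name` (kept) and the removed `update_expected_keys` both deal with the fact that a pre-quantized MXFP4 checkpoint stores each expert projection as a `_blocks`/`_scales` pair rather than a single tensor. Below is a minimal sketch of that key mapping, using a hypothetical helper that is not part of the patch:

```python
# Illustration only: a pre-quantized MXFP4 checkpoint stores each expert projection
# ("gate_up_proj" / "down_proj") as a packed "_blocks" tensor plus a "_scales" tensor.
# This hypothetical helper mirrors the key expansion the removed update_expected_keys
# performed for the pre-quantized case.
def expand_expert_keys(expected_keys):
    expanded = []
    for key in expected_keys:
        if key.endswith((".mlp.experts.gate_up_proj", ".mlp.experts.down_proj")):
            expanded.append(key + "_blocks")
            expanded.append(key + "_scales")
        else:
            expanded.append(key)
    return expanded


print(expand_expert_keys(["model.layers.0.mlp.experts.gate_up_proj", "lm_head.weight"]))
# ['model.layers.0.mlp.experts.gate_up_proj_blocks',
#  'model.layers.0.mlp.experts.gate_up_proj_scales',
#  'lm_head.weight']
```
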
diff --git a/src/transformers/quantizers/quantizer_quanto.py b/src/transformers/quantizers/quantizer_quanto.py
index 25af8d2874c6..6525efa0ff52 100644
--- a/src/transformers/quantizers/quantizer_quanto.py
+++ b/src/transformers/quantizers/quantizer_quanto.py
@@ -40,8 +40,6 @@ class QuantoHfQuantizer(HfQuantizer):
Quantizer for the quanto library
"""
- required_packages = ["quanto", "accelerate"]
- requires_parameters_quantization = True
requires_calibration = False
def __init__(self, quantization_config: QuantoConfig, **kwargs):
@@ -57,12 +55,8 @@ def validate_environment(self, *args, **kwargs):
"Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)"
)
device_map = kwargs.get("device_map")
- if device_map is not None:
- if (
- isinstance(device_map, dict)
- and len(device_map) >= 2
- and ("cpu" in device_map.values() or "disk" in device_map.values())
- ):
+ if isinstance(device_map, dict):
+            if len(device_map) > 1 and ("cpu" in device_map.values() or "disk" in device_map.values()):
raise ValueError(
"You are attempting to load an model with a device_map that contains a CPU or disk device."
"This is not supported with quanto when the model is quantized on the fly. "
@@ -113,7 +107,6 @@ def _process_model_before_weight_loading(
model = replace_with_quanto_layers(
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
)
- model.config.quantization_config = self.quantization_config
@property
def is_trainable(self) -> bool:
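
The parenthesization in the quanto `device_map` guard above matters because Python's `and` binds more tightly than `or`. A small, self-contained illustration with a hypothetical single-device offload map (not taken from the patch):

```python
# Without parentheses the guard reduces to
# (len(device_map) > 1 and "cpu" in ...) or ("disk" in ...),
# which would raise even for a single-entry disk map that the original
# `len(device_map) >= 2 and (...)` check allowed.
device_map = {"": "disk"}  # hypothetical offload map with a single device

unparenthesized = len(device_map) > 1 and "cpu" in device_map.values() or "disk" in device_map.values()
parenthesized = len(device_map) > 1 and ("cpu" in device_map.values() or "disk" in device_map.values())

print(unparenthesized)  # True  -> would raise the ValueError
print(parenthesized)    # False -> no error, matching the pre-existing behavior
```
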
diff --git a/src/transformers/quantizers/quantizer_quark.py b/src/transformers/quantizers/quantizer_quark.py
index c11548b1416e..c8c8eb6c7f26 100644
--- a/src/transformers/quantizers/quantizer_quark.py
+++ b/src/transformers/quantizers/quantizer_quark.py
@@ -45,12 +45,6 @@ class QuarkHfQuantizer(HfQuantizer):
"""
requires_calibration = True # On-the-fly quantization with quark is not supported for now.
- required_packages = ["quark"]
-
- # Checkpoints are expected to be already quantized when loading a quark model. However, as some keys from
- # the checkpoint might mismatch the model parameters keys, we use the `create_quantized_param` method
- # to load the checkpoints, remapping the keys.
- requires_parameters_quantization = True
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
@@ -78,17 +72,7 @@ def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwarg
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
return True
- def create_quantized_param(self, model, param, param_name, param_device, **kwargs):
- from ..modeling_utils import _load_parameter_into_model
-
- postfix = param_name.split(".")[-1]
-
- if postfix in CHECKPOINT_KEYS:
- param_name = param_name.replace(postfix, CHECKPOINT_KEYS[postfix])
-
- _load_parameter_into_model(model, param_name, param.to(param_device))
-
- def is_serializable(self, safe_serialization=None):
+ def is_serializable(self, **kwargs):
return False
@property
diff --git a/src/transformers/quantizers/quantizer_spqr.py b/src/transformers/quantizers/quantizer_spqr.py
index ce549b732ca2..710714160772 100644
--- a/src/transformers/quantizers/quantizer_spqr.py
+++ b/src/transformers/quantizers/quantizer_spqr.py
@@ -39,7 +39,6 @@ class SpQRHfQuantizer(HfQuantizer):
def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
- self.quantization_config = quantization_config
def validate_environment(self, *args, **kwargs):
if not torch.cuda.is_available():
@@ -71,17 +70,15 @@ def _process_model_before_weight_loading(
self.modules_to_not_convert = self.get_modules_to_not_convert(
model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
)
-
replace_with_spqr_linear(
model,
quantization_config=self.quantization_config,
modules_to_not_convert=self.modules_to_not_convert,
)
- model.config.quantization_config = self.quantization_config
@property
def is_trainable(self):
return False
- def is_serializable(self, safe_serialization=None):
+ def is_serializable(self, **kwargs):
return True
diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py
index e7919b7f81b7..e5a6de6478c5 100644
--- a/src/transformers/quantizers/quantizer_torchao.py
+++ b/src/transformers/quantizers/quantizer_torchao.py
@@ -13,8 +13,6 @@
# limitations under the License.
import importlib
import re
-import types
-from collections import defaultdict
from typing import TYPE_CHECKING
from packaging import version
@@ -37,17 +35,12 @@
if is_torch_available():
import torch
- import torch.nn as nn
if is_torchao_available():
- import torchao
-
- if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
+ if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.15.0"):
from torchao.prototype.safetensors.safetensors_support import (
flatten_tensor_state_dict,
- unflatten_tensor_state_dict,
)
- from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao
logger = logging.get_logger(__name__)
@@ -88,11 +81,6 @@ def _linear_extra_repr(self):
if is_torchao_available():
- SUPPORTED_SAFE_SERIALIZATION_CONFIGS = [
- torchao.quantization.Float8WeightOnlyConfig,
- torchao.quantization.Float8DynamicActivationFloat8WeightConfig,
- ]
-
TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
@@ -101,9 +89,7 @@ class TorchAoHfQuantizer(HfQuantizer):
Quantizer for torchao: https://github.com/pytorch/ao/
"""
- requires_parameters_quantization = True
requires_calibration = False
- required_packages = ["torchao"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
@@ -171,12 +157,12 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool | None = F
If the model is safe serializable, we flatten the state dict of tensor subclasses so that it is compatible with
the safetensors format.
"""
- if type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and safe_serialization:
- if TORCHAO_VERSION >= version.parse("0.14.0"):
+ if safe_serialization:
+ if TORCHAO_VERSION >= version.parse("0.15.0"):
return flatten_tensor_state_dict(model.state_dict())
else:
raise RuntimeError(
- f"In order to use safetensors with torchao, please use torchao version >= 0.14.0. Current version: {TORCHAO_VERSION}"
+ f"In order to use safetensors with torchao, please use torchao version >= 0.15.0. Current version: {TORCHAO_VERSION}"
)
else:
return None, {}
@@ -237,9 +223,6 @@ def _process_model_before_weight_loading(
]
return
- def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
- return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.full_ao_keys)]
-
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
if self.pre_quantized:
return False
@@ -249,8 +232,6 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **
# check if the param_name is not in self.modules_to_not_convert
if any(key + "." in param_name or key == param_name for key in self.modules_to_not_convert):
return False
- elif any(param_name.endswith(f":{x}") for x in self.full_ao_keys):
- return True
# we only quantize the weight of nn.Linear and nn.Embedding
module, tensor_name = get_module_from_name(model, param_name)
@@ -276,148 +257,6 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **
return isinstance(module, tuple(_QUANTIZABLE)) and tensor_name == "weight"
- def create_quantized_param(
- self,
- model: "PreTrainedModel",
- param_value: "torch.Tensor",
- param_name: str,
- target_device: "torch.device",
- **kwargs,
- ):
- """
- Each nn.Linear layer that needs to be quantized is processed here.
- First, we set the value the weight tensor, then we move it to the target device. Finally, we quantize the module.
- """
- from torchao.quantization import quantize_
-
- full_name = param_name
- # Those are the pre quantized weights
- if ":" in param_name:
- param_name = param_name.rsplit(":", 1)[0]
- module, tensor_name = get_module_from_name(model, param_name)
-
- if self.pre_quantized:
- # If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but was
- # already done) - if it's unsafe-serialized (i.e. not safetensors), not need for anything either
- is_unsafe_serialization = ":" not in full_name
- if tensor_name == "bias" or is_unsafe_serialization:
- module._parameters[tensor_name] = torch.nn.Parameter(
- param_value.to(target_device), requires_grad=param_value.requires_grad
- )
- return
- # Sanity check for the new serialization format
- elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata)):
- raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
-
- # Save the states for later quantization when they are all gathered
- if not hasattr(self, "ao_params"):
- self.ao_params = defaultdict(dict)
- self.ao_params[param_name].update({full_name: param_value})
-
- # We are ready for quantization in this case (we retrieved all the needed keys)
- if len(self.ao_params[param_name]) == len(self.weight_ao_keys):
- new_param = unflatten_tensor_state_dict(self.ao_params[param_name], self.metadata)[param_name]
- # Set it
- module._parameters[tensor_name] = torch.nn.Parameter(
- new_param.to(target_device), requires_grad=new_param.requires_grad
- )
-
- # Free memory
- del self.ao_params[param_name]
-
- # Add repr to the module
- if isinstance(module, nn.Linear):
- module.extra_repr = types.MethodType(_linear_extra_repr, module)
- else:
- module._parameters[tensor_name] = torch.nn.Parameter(
- param_value, requires_grad=param_value.requires_grad
- ).to(target_device)
- # if we are quantizing tied parameters, to avoid tying the quantized weights
- # the correct order to do it is
- # 1. load the weight to model
- # 2. run tie_weights to populate the weights
- # 3. quantize
- input_embed = model.get_input_embeddings()
- if self.quantization_config.untie_embedding_weights and id(module) == id(input_embed):
- model.tie_weights()
- setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False)
-
- # handle FqnToConfig, introduced in torchao 0.15.0+
- if self.quantization_config._get_ao_version() >= version.Version("0.15.0"):
- from torchao.quantization import FqnToConfig
-
- config = self.quantization_config.get_apply_tensor_subclass()
- if isinstance(config, FqnToConfig):
- module_fqn, top_level_param_name = param_name.rsplit(".", 1)
- c = None
- if param_name in config.fqn_to_config:
- assert not module_fqn.startswith("re:"), (
- "param fqn should not start with`re:`, which is used for specifying regex"
- )
- c = config.module_fqn_to_config[param_name]
- elif module_fqn in config.fqn_to_config:
- assert not module_fqn.startswith("re:"), (
- "module fqn should not start with`re:`, which is used for specifying regex"
- )
- c = config.module_fqn_to_config[module_fqn]
- # regex match module and param
- else:
- for maybe_module_fqn_pattern in config.fqn_to_config:
- # if key doesn't start with re, it is an exact fqn key, so we don't regex match
- if not maybe_module_fqn_pattern.startswith("re:"):
- continue
- # see if param matches first
- elif re.fullmatch(maybe_module_fqn_pattern[3:], param_name):
- c = config.module_fqn_to_config[maybe_module_fqn_pattern]
- break
- elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
- # we'll apply the config for first fully matched pattern
- c = config.module_fqn_to_config[maybe_module_fqn_pattern]
- break
- else:
- c = config.module_fqn_to_config.get("_default", None)
-
- if c is not None:
- if top_level_param_name == "weight":
- # we can apply the module config directly
- quantize_(module, c, (lambda x, fqn: True))
- else:
- # need to apply to custom param name
- custom_param_fqn_config = FqnToConfig({top_level_param_name: c})
- quantize_(module, custom_param_fqn_config, filter_fn=None)
- return
-
- # handle ModuleFqnToConfig, introduced in torchao 0.12.0+
- # TODO deprecate this when we deprecate ModuleFqnToConfig
- elif self.quantization_config._get_ao_version() >= version.Version("0.12.0"):
- from torchao.quantization import ModuleFqnToConfig
-
- config = self.quantization_config.get_apply_tensor_subclass()
- if isinstance(config, ModuleFqnToConfig):
- module_fqn, _ = param_name.rsplit(".", 1)
- c = None
- if module_fqn in config.module_fqn_to_config:
- assert not module_fqn.startswith("re:"), (
- "module fqn should not start with`re:`, which is used for specifying regex"
- )
- c = config.module_fqn_to_config[module_fqn]
- else:
- for maybe_module_fqn_pattern in config.module_fqn_to_config:
- if not maybe_module_fqn_pattern.startswith("re:"):
- continue
- elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
- # we'll apply the config for first fully matched pattern
- c = config.module_fqn_to_config[maybe_module_fqn_pattern]
- break
- else:
- c = config.module_fqn_to_config.get("_default", None)
- if c is not None:
- # filter_fn: not filtering out any modules
- quantize_(module, c, filter_fn=lambda x, fqn: True)
- return
-
- quantize_(module, self.quantization_config.get_apply_tensor_subclass())
-
def preprocess_model(self, model: "PreTrainedModel", config, dtype=None, checkpoint_files=None, **kwargs):
"""
Setting model attributes and/or converting model before weights loading. At this point
@@ -452,29 +291,21 @@ def _process_model_after_weight_loading(self, model, **kwargs):
def is_serializable(self, safe_serialization=None) -> bool:
if safe_serialization:
- _is_torchao_serializable = type(
- self.quantization_config.quant_type
- ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.14.0")
- if not _is_torchao_serializable:
+ _is_torchao_serializable = TORCHAO_VERSION >= version.parse("0.15.0")
+            if not _is_torchao_serializable:
logger.warning(
- f"torchao quantized model only supports safe serialization for {SUPPORTED_SAFE_SERIALIZATION_CONFIGS}, \
- and torchao version >= 0.14.0, please set `safe_serialization` to False for \
+ f"torchao quantized model only supports safe serialization for torchao version >= 0.15.0, please set `safe_serialization` to False for \
{type(self.quantization_config.quant_type)} and {TORCHAO_VERSION}."
)
return _is_torchao_serializable
- _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
- "0.25.0"
- )
- if not _is_torchao_serializable:
- logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
if self.offload and self.quantization_config.modules_to_not_convert is None:
logger.warning(
"The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them."
"If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config."
)
return False
- return _is_torchao_serializable
+ return True
def get_accelerator_warm_up_factor(self):
"""
@@ -548,15 +379,18 @@ def get_weight_conversions(self):
if self.pre_quantized:
return [
WeightConverter(
- source_patterns=["weight:qdata", "weight:scale", "weight:zero_point"],
- target_patterns="weight",
- operations=[TorchAoDeserialize(self)],
- ),
- WeightConverter(
- source_patterns=["weight:_data"],
+                    # TODO: increase flexibility by generalizing the source patterns to match the "_weight_" format
+                    # Note that the matching logic is greedy: for example, if _weight_scale appeared before _weight_scale_and_zero
+                    # in this list, it would always match _weight_scale (which is incorrect), so the order of source_patterns is intentional.
+ source_patterns=[
+ "_weight_qdata",
+ "_weight_scale_and_zero",
+ "_weight_scale",
+ "_weight_zero_point",
+ "_weight_act_pre_scale",
+ ],
target_patterns="weight",
operations=[TorchAoDeserialize(self)],
),
- # used for unsafe serialization
]
return []
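
The ordering caveat in the `get_weight_conversions` comment above comes down to greedy, first-match pattern selection. A toy matcher (hypothetical, not the actual `WeightConverter` logic) makes the failure mode concrete:

```python
# Toy first-match resolver: return the first source pattern found in the key.
# Only meant to illustrate the greedy behaviour described in the comment above.
def first_match(key, patterns):
    for pattern in patterns:
        if pattern in key:
            return pattern
    return None


key = "model.layers.0.self_attn.q_proj._weight_scale_and_zero"

bad_order = ["_weight_scale", "_weight_scale_and_zero"]
good_order = ["_weight_scale_and_zero", "_weight_scale"]

print(first_match(key, bad_order))   # _weight_scale          -> wrong bucket
print(first_match(key, good_order))  # _weight_scale_and_zero -> intended match
```
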
diff --git a/src/transformers/quantizers/quantizer_vptq.py b/src/transformers/quantizers/quantizer_vptq.py
index e5e30152261c..09808b97cf65 100644
--- a/src/transformers/quantizers/quantizer_vptq.py
+++ b/src/transformers/quantizers/quantizer_vptq.py
@@ -35,11 +35,9 @@ class VptqHfQuantizer(HfQuantizer):
"""
requires_calibration = True
- required_packages = ["vptq"]
def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
- self.quantization_config = quantization_config
def validate_environment(self, *args, **kwargs):
if not is_accelerate_available():
@@ -48,21 +46,15 @@ def validate_environment(self, *args, **kwargs):
if not is_vptq_available():
raise ImportError("Using `vptq` quantization requires VPTQ>=0.0.4: `pip install -U vptq`")
+ if not torch.cuda.is_available():
+ raise RuntimeError("GPU is required to run VTPQ quantized model.")
+
def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
if dtype is None:
- if torch.cuda.is_available():
- dtype = torch.float16
- logger.info(
- "CUDA available. Assuming VPTQ inference on GPU and loading the model in `torch.float16`. To overwrite it, set `dtype` manually."
- )
- else:
- import vptq
-
- device_availability = getattr(vptq, "device_availability", lambda device: False)
- if device_availability("cpu") is True:
- raise RuntimeError("No GPU found. Please wait for the next release of VPTQ to use CPU inference")
- dtype = torch.float32
- logger.info("No GPU found. Assuming VPTQ inference on CPU and loading the model in `torch.float32`.")
+ dtype = torch.float16
+ logger.info(
+ "Assuming VPTQ inference on GPU and loading the model in `torch.float16`. To overwrite it, set `dtype` manually."
+ )
return dtype
def _process_model_before_weight_loading(
@@ -71,22 +63,16 @@ def _process_model_before_weight_loading(
keep_in_fp32_modules: list[str] | None = None,
**kwargs,
):
- """
- we don't have param like modules_to_not_convert to indicate which layers should not be quantized
- because `quantization_config` include the layers that should be quantized
- """
from ..integrations import replace_with_vptq_linear
self.modules_to_not_convert = self.get_modules_to_not_convert(
model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
)
-
replace_with_vptq_linear(
model,
quantization_config=self.quantization_config,
modules_to_not_convert=self.modules_to_not_convert,
)
- model.config.quantization_config = self.quantization_config
@property
def is_trainable(self) -> bool:
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index e05d5dbbb72f..8ae930c63182 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -267,6 +267,7 @@ def parse_int_from_env(key, default=None):
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
_run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False)
+_run_training_tests = parse_flag_from_env("RUN_TRAINING_TESTS", default=True)
def is_staging_test(test_case):
@@ -317,6 +318,22 @@ def is_agent_test(test_case):
return pytest.mark.is_agent_test()(test_case)
+def is_training_test(test_case):
+ """
+ Decorator marking a test as a training test. If RUN_TRAINING_TESTS is set to a falsy value, those tests will be
+ skipped.
+ """
+ if not _run_training_tests:
+ return unittest.skip(reason="test is training test")(test_case)
+ else:
+ try:
+ import pytest # We don't need a hard dependency on pytest in the main library
+ except ImportError:
+ return test_case
+ else:
+ return pytest.mark.is_training_test()(test_case)
+
+
def slow(test_case):
"""
Decorator marking a test as slow.
@@ -638,6 +655,9 @@ def require_read_token(test_case):
if getattr(attr, "__require_read_token__", False):
continue
wrapped = require_read_token(attr)
+ if isinstance(inspect.getattr_static(test_case, attr_name), staticmethod):
+ # Don't accidentally bind staticmethods to `self`
+ wrapped = staticmethod(wrapped)
setattr(test_case, attr_name, wrapped)
return test_case
else:
@@ -650,10 +670,6 @@ def wrapper(*args, **kwargs):
with patch("huggingface_hub.utils._headers.get_token", return_value=token):
return test_case(*args, **kwargs)
else: # Allow running locally with the default token env variable
- # dealing with static/class methods and called by `self.xxx`
- if "staticmethod" in inspect.getsource(test_case).strip():
- if len(args) > 0 and isinstance(args[0], unittest.TestCase):
- return test_case(*args[1:], **kwargs)
return test_case(*args, **kwargs)
wrapper.__require_read_token__ = True
@@ -4078,3 +4094,222 @@ def write_file(file, content):
def read_json_file(file):
with open(file, "r") as fh:
return json.load(fh)
+
+
+# =============================================================================
+# Training CI Utilities - Logging and Memory Monitoring
+# =============================================================================
+
+
+# ANSI color codes for terminal output
+class Colors:
+ """ANSI color codes for terminal output formatting."""
+
+ RESET = "\033[0m"
+ BOLD = "\033[1m"
+ DIM = "\033[2m"
+
+ # Foreground colors
+ RED = "\033[31m"
+ GREEN = "\033[32m"
+ YELLOW = "\033[33m"
+ BLUE = "\033[34m"
+ MAGENTA = "\033[35m"
+ CYAN = "\033[36m"
+ WHITE = "\033[37m"
+
+ # Bright variants
+ BRIGHT_RED = "\033[91m"
+ BRIGHT_GREEN = "\033[92m"
+ BRIGHT_YELLOW = "\033[93m"
+ BRIGHT_BLUE = "\033[94m"
+ BRIGHT_CYAN = "\033[96m"
+
+
+class ColoredFormatter(logging.Formatter):
+ """Custom formatter that adds colors based on log level."""
+
+ LEVEL_COLORS = {
+ logging.DEBUG: Colors.DIM + Colors.CYAN,
+ logging.INFO: Colors.WHITE,
+ logging.WARNING: Colors.BRIGHT_YELLOW,
+ logging.ERROR: Colors.BRIGHT_RED,
+ logging.CRITICAL: Colors.BOLD + Colors.BRIGHT_RED,
+ }
+
+ # Loggers that should be dimmed (less important/verbose)
+ DIMMED_LOGGERS = {"httpx", "httpcore", "urllib3", "requests"}
+
+ def __init__(self, fmt: str | None = None, datefmt: str | None = None):
+ super().__init__(fmt, datefmt)
+
+ def format(self, record: logging.LogRecord) -> str:
+ # Check if this logger should be dimmed
+ is_dimmed = record.name in self.DIMMED_LOGGERS
+
+ if is_dimmed:
+ # Dim the entire log line for httpx and similar
+ timestamp = self.formatTime(record, self.datefmt)
+ message = record.getMessage()
+ return f"{Colors.DIM}{timestamp} - {record.name} - {record.levelname:8} - {message}{Colors.RESET}"
+
+ # Get color for this level
+ color = self.LEVEL_COLORS.get(record.levelno, Colors.RESET)
+
+ # Color the level name
+ levelname = record.levelname
+ colored_levelname = f"{color}{levelname:8}{Colors.RESET}"
+
+ # Color the timestamp
+ colored_time = f"{Colors.DIM}{self.formatTime(record, self.datefmt)}{Colors.RESET}"
+
+ # Color the logger name
+ colored_name = f"{Colors.BLUE}{record.name}{Colors.RESET}"
+
+ # Get message
+ message = record.getMessage()
+
+ return f"{colored_time} - {colored_name} - {colored_levelname} - {message}"
+
+
+_warn_once_logged: set[str] = set()
+
+
+def init_test_logger() -> logging.Logger:
+ """Initialize a test-specific logger with colored stderr handler and INFO level for tests.
+
+ Uses a named logger instead of root logger to avoid conflicts with pytest-xdist parallel execution.
+ Uses stderr instead of stdout to avoid deadlocks with pytest-xdist output capture.
+ """
+ logger = logging.getLogger("transformers.training_test")
+ logger.setLevel(logging.INFO)
+
+ # Only add handler if not already present (avoid duplicate handlers on repeated calls)
+ if not logger.handlers:
+ # Use stderr instead of stdout - pytest-xdist captures stdout which can cause deadlocks
+ ch = logging.StreamHandler(sys.stderr)
+ ch.setLevel(logging.INFO)
+
+ # Use colored formatter if terminal supports it, plain otherwise
+ if sys.stderr.isatty():
+ formatter = ColoredFormatter(datefmt="%Y-%m-%d %H:%M:%S")
+ else:
+ formatter = logging.Formatter(
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+ )
+
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+
+ logger.propagate = False # Don't propagate to root logger to avoid duplicate output
+ return logger
+
+
+def warn_once(logger_instance: logging.Logger, msg: str) -> None:
+ """Log a warning message only once per unique message.
+
+ Uses a global set to track messages that have already been logged
+ to prevent duplicate warning messages from cluttering the output.
+
+ Args:
+ logger_instance: The logger instance to use for warning.
+ msg: The warning message to log.
+ """
+ if msg not in _warn_once_logged:
+ logger_instance.warning(msg)
+ _warn_once_logged.add(msg)
+
+
+# Named tuple for passing memory stats for logging
+MemoryStats = collections.namedtuple(
+ "MemoryStats",
+ [
+ "rss_gib", # Resident Set Size in GiB
+ "rss_pct", # RSS as percentage of total memory
+ "vms_gib", # Virtual Memory Size in GiB
+ "peak_rss_gib", # Peak RSS in GiB
+ "peak_rss_pct", # Peak RSS as percentage of total memory
+ "available_gib", # Available system memory in GiB
+ "total_gib", # Total system memory in GiB
+ ],
+)
+
+
+class CPUMemoryMonitor:
+ """Monitor CPU memory usage for the current process."""
+
+ def __init__(self):
+ self.device_name = "CPU"
+ self._peak_rss = 0
+ self._process = None
+ self.total_memory = 0
+ self.total_memory_gib = 0
+
+ if is_psutil_available():
+ import psutil
+
+ self._process = psutil.Process(os.getpid())
+ mem_info = psutil.virtual_memory()
+ self.total_memory = mem_info.total
+ self.total_memory_gib = self._to_gib(self.total_memory)
+
+ def _to_gib(self, memory_in_bytes: int) -> float:
+ """Convert bytes to GiB."""
+ return memory_in_bytes / (1024 * 1024 * 1024)
+
+ def _to_pct(self, memory_in_bytes: int) -> float:
+ """Convert bytes to percentage of total memory."""
+ if self.total_memory == 0:
+ return 0.0
+ return 100.0 * memory_in_bytes / self.total_memory
+
+ def _update_peak(self) -> None:
+ """Update peak memory tracking."""
+ if self._process is not None:
+ current_rss = self._process.memory_info().rss
+ self._peak_rss = max(self._peak_rss, current_rss)
+
+ def get_stats(self) -> MemoryStats:
+ """Get current memory statistics."""
+ if not is_psutil_available():
+ return MemoryStats(0, 0, 0, 0, 0, 0, 0)
+
+ import psutil
+
+ self._update_peak()
+
+ mem_info = self._process.memory_info()
+ sys_mem = psutil.virtual_memory()
+
+ return MemoryStats(
+ rss_gib=self._to_gib(mem_info.rss),
+ rss_pct=self._to_pct(mem_info.rss),
+ vms_gib=self._to_gib(mem_info.vms),
+ peak_rss_gib=self._to_gib(self._peak_rss),
+ peak_rss_pct=self._to_pct(self._peak_rss),
+ available_gib=self._to_gib(sys_mem.available),
+ total_gib=self._to_gib(sys_mem.total),
+ )
+
+ def reset_peak_stats(self) -> None:
+ """Reset peak memory tracking."""
+ if self._process is not None:
+ self._peak_rss = self._process.memory_info().rss
+
+
+def build_cpu_memory_monitor(logger_instance: logging.Logger | None = None) -> CPUMemoryMonitor:
+ """Build and initialize a CPU memory monitor.
+
+ Args:
+ logger_instance: Optional logger to log initialization info. If None, no logging is done.
+
+ Returns:
+ CPUMemoryMonitor instance.
+ """
+ monitor = CPUMemoryMonitor()
+ if logger_instance is not None:
+ if is_psutil_available():
+ logger_instance.info(f"CPU memory monitor initialized: {monitor.total_memory_gib:.2f} GiB total")
+ else:
+ logger_instance.warning("psutil not available, memory monitoring disabled")
+ return monitor
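
A short usage sketch of the training-CI helpers added above, as they might appear in a test once this patch is applied; the training loop itself is a placeholder, not code from the patch:

```python
# Assumes the testing_utils additions above are available; the "training step" is a stand-in.
from transformers.testing_utils import build_cpu_memory_monitor, init_test_logger, warn_once

logger = init_test_logger()
monitor = build_cpu_memory_monitor(logger)

for step in range(3):
    # ... run one training step on a tiny model here ...
    stats = monitor.get_stats()
    logger.info(
        f"step {step}: rss={stats.rss_gib:.2f} GiB ({stats.rss_pct:.1f}%), peak={stats.peak_rss_gib:.2f} GiB"
    )
    if stats.available_gib < 1.0:
        warn_once(logger, "Less than 1 GiB of system memory available; results may be flaky.")

monitor.reset_peak_stats()
```
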
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index a0f072de7678..9670573cac67 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -33,7 +33,7 @@
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
import numpy as np
-from huggingface_hub import create_repo, list_repo_files
+from huggingface_hub import create_repo, is_offline_mode, list_repo_files
from packaging import version
from . import __version__
@@ -51,7 +51,6 @@
extract_commit_hash,
is_mlx_available,
is_numpy_array,
- is_offline_mode,
is_protobuf_available,
is_tokenizers_available,
is_torch_available,
diff --git a/src/transformers/tokenization_utils_tokenizers.py b/src/transformers/tokenization_utils_tokenizers.py
index 988acbf91aae..183a2cf797a3 100644
--- a/src/transformers/tokenization_utils_tokenizers.py
+++ b/src/transformers/tokenization_utils_tokenizers.py
@@ -24,6 +24,7 @@
from typing import Any, Optional, Union
import tokenizers.pre_tokenizers as pre_tokenizers_fast
+from huggingface_hub import is_offline_mode
from tokenizers import AddedToken, processors
from tokenizers import Encoding as EncodingFast
from tokenizers import Tokenizer as TokenizerFast
@@ -42,7 +43,7 @@
TextInput,
TruncationStrategy,
)
-from .utils import PaddingStrategy, add_end_docstrings, is_offline_mode, logging
+from .utils import PaddingStrategy, add_end_docstrings, logging
logger = logging.get_logger(__name__)
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 3ae0c2dc6694..e4c0156a58fb 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -5087,14 +5087,14 @@ def create_accelerator_and_postprocess(self):
self.is_tp_enabled = False
if getattr(self.model, "tp_size", None) is not None and self.model.tp_size > 1:
self.is_tp_enabled = True
- if self.args.parallelism_config is not None:
- if is_accelerate_available("1.10.1"):
- if self.args.parallelism_config is not None:
+ if self.args.parallelism_config is None:
+ if is_accelerate_available("1.12.0"):
+ if self.args.parallelism_config is None:
from accelerate import ParallelismConfig
args["parallelism_config"] = ParallelismConfig(tp_size=self.model.tp_size)
else:
- raise ValueError("Requires accelerate>1.10.1 to use Tensor Parallelism.")
+ raise ValueError("Requires accelerate>1.12.0 to use Tensor Parallelism.")
if is_accelerate_available("1.2.0"):
# it we don't have the correct version, we will rely on env var instead that were set in TrainingArguments
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 963bed209ac5..09cb48cd7f74 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -596,9 +596,9 @@ class TrainingArguments:
instance of `Dataset`.
report_to (`str` or `list[str]`, *optional*, defaults to `"none"`):
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
- `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"neptune"`,
- `"swanlab"`, `"tensorboard"`, `"trackio"` and `"wandb"`. Use `"all"` to report to all integrations
- installed, `"none"` for no integrations.
+ `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"swanlab"`,
+ `"tensorboard"`, `"trackio"` and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"`
+ for no integrations.
project (`str`, *optional*, defaults to `"huggingface"`):
The name of the project to use for logging. Currently, only used by Trackio.
trackio_space_id (`str` or `None`, *optional*, defaults to `"trackio"`):
@@ -2386,8 +2386,8 @@ def set_logging(
report_to (`str` or `list[str]`, *optional*, defaults to `"none"`):
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`,
- `"neptune"`, `"swanlab"`, `"tensorboard"`, `"trackio"` and `"wandb"`. Use `"all"` to report to all
- integrations installed, `"none"` for no integrations.
+ `"swanlab"`, `"tensorboard"`, `"trackio"` and `"wandb"`. Use `"all"` to report to all integrations
+ installed, `"none"` for no integrations.
first_step (`bool`, *optional*, defaults to `False`):
Whether to log and evaluate the first `global_step` or not.
nan_inf_filter (`bool`, *optional*, defaults to `True`):
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index b32488d0da04..229a3c5df350 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -91,7 +91,6 @@
extract_commit_hash,
has_file,
http_user_agent,
- is_offline_mode,
list_repo_templates,
try_to_load_from_cache,
)
@@ -129,6 +128,8 @@
is_datasets_available,
is_decord_available,
is_detectron2_available,
+ is_env_variable_false,
+ is_env_variable_true,
is_essentia_available,
is_faiss_available,
is_fbgemm_gpu_available,
diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py
index 72a2f245cf19..133ac8726cd2 100644
--- a/src/transformers/utils/auto_docstring.py
+++ b/src/transformers/utils/auto_docstring.py
@@ -67,6 +67,7 @@
"donut": "DonutSwinConfig",
"esmfold": "EsmConfig",
"parakeet": "ParakeetCTCConfig",
+ "lasr": "LasrCTCConfig",
}
_re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")
diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py
index 406ef76b8a32..fa6319578353 100644
--- a/src/transformers/utils/hub.py
+++ b/src/transformers/utils/hub.py
@@ -37,6 +37,7 @@
create_repo,
hf_hub_download,
hf_hub_url,
+ is_offline_mode,
list_repo_tree,
snapshot_download,
try_to_load_from_cache,
@@ -83,13 +84,6 @@ class DownloadKwargs(TypedDict, total=False):
commit_hash: str | None
-def is_offline_mode():
- # Import inside the function so test patches on `huggingface_hub.constants` are picked up.
- from huggingface_hub import constants as hf_hub_constants
-
- return hf_hub_constants.HF_HUB_OFFLINE
-
-
# Determine default cache directory.
# The best way to set the cache path is with the environment variable HF_HOME. For more details, check out this
# documentation page: https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables.
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 732334d5f1a8..af34c54ed305 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -77,6 +77,16 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[
return package_exists
+def is_env_variable_true(env_variable: str) -> bool:
+ """Detect whether `env_variable` has been set to a true value in the environment"""
+ return os.getenv(env_variable, "false").lower() in ("true", "1", "y", "yes", "on")
+
+
+def is_env_variable_false(env_variable: str) -> bool:
+ """Detect whether `env_variable` has been set to a false value in the environment"""
+ return os.getenv(env_variable, "true").lower() in ("false", "0", "n", "no", "off")
+
+
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
@@ -1298,6 +1308,34 @@ def is_torch_fx_proxy(x):
return False
+def is_jax_jitting(x):
+ """returns True if we are inside of `jax.jit` context, False otherwise.
+
+ When a torch model is being compiled with `jax.jit` using torchax,
+ the tensor that goes through the model would be an instance of
+ `torchax.tensor.Tensor`, which is a tensor subclass. This tensor has
+ a `jax` method to return the inner Jax array
+ (https://github.com/google/torchax/blob/13ce870a1d9adb2430333c27bb623469e3aea34e/torchax/tensor.py#L134).
+    Here we use duck typing to detect whether the inner jax array is a jax Tracer;
+    if so, we are in a tracing context. (See more at: https://github.com/jax-ml/jax/discussions/9241)
+
+ Args:
+ x: torch.Tensor
+
+ Returns:
+ bool: whether we are inside of jax jit tracing.
+ """
+
+ if not hasattr(x, "jax"):
+ return False
+ try:
+ import jax
+
+ return isinstance(x.jax(), jax.core.Tracer)
+ except Exception:
+ return False
+
+
def is_jit_tracing() -> bool:
try:
import torch
@@ -1317,12 +1355,14 @@ def is_cuda_stream_capturing() -> bool:
def is_tracing(tensor=None) -> bool:
- """Checks whether we are tracing a graph with dynamo (compile or export), torch.jit, torch.fx or CUDA stream capturing"""
+ """Checks whether we are tracing a graph with dynamo (compile or export), torch.jit, torch.fx, jax.jit (with torchax) or
+ CUDA stream capturing"""
# Note that `is_torchdynamo_compiling` checks both compiling and exporting (the export check is stricter and
# only checks export)
_is_tracing = is_torchdynamo_compiling() or is_jit_tracing() or is_cuda_stream_capturing()
if tensor is not None:
_is_tracing |= is_torch_fx_proxy(tensor)
+ _is_tracing |= is_jax_jitting(tensor)
return _is_tracing
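
One subtlety of the new environment-variable helpers above: they are not strict complements. With the variable unset, `is_env_variable_true` falls back to "false" and `is_env_variable_false` falls back to "true", so both return False. A quick standalone sketch (the helpers are re-declared here only so the snippet runs on its own):

```python
import os


# Local copies of the helpers defined above, for illustration only.
def is_env_variable_true(env_variable: str) -> bool:
    return os.getenv(env_variable, "false").lower() in ("true", "1", "y", "yes", "on")


def is_env_variable_false(env_variable: str) -> bool:
    return os.getenv(env_variable, "true").lower() in ("false", "0", "n", "no", "off")


os.environ.pop("MY_FLAG", None)
print(is_env_variable_true("MY_FLAG"), is_env_variable_false("MY_FLAG"))  # False False (unset)

os.environ["MY_FLAG"] = "yes"
print(is_env_variable_true("MY_FLAG"), is_env_variable_false("MY_FLAG"))  # True False

os.environ["MY_FLAG"] = "off"
print(is_env_variable_true("MY_FLAG"), is_env_variable_false("MY_FLAG"))  # False True
```
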
diff --git a/src/transformers/video_processing_utils.py b/src/transformers/video_processing_utils.py
index 6c98439356e2..d73bbce889f1 100644
--- a/src/transformers/video_processing_utils.py
+++ b/src/transformers/video_processing_utils.py
@@ -22,7 +22,7 @@
from typing import Any, Optional, Union
import numpy as np
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, is_offline_mode
from huggingface_hub.dataclasses import validate_typed_dict
from .dynamic_module_utils import custom_object_save
@@ -44,7 +44,6 @@
TensorType,
add_start_docstrings,
copy_func,
- is_offline_mode,
is_torch_available,
is_torchcodec_available,
is_torchvision_v2_available,
diff --git a/tests/causal_lm_tester.py b/tests/causal_lm_tester.py
index 4d6a450d7976..c1f058ff8089 100644
--- a/tests/causal_lm_tester.py
+++ b/tests/causal_lm_tester.py
@@ -38,6 +38,7 @@
torch_device,
)
from .test_pipeline_mixin import PipelineTesterMixin
+from .test_training_mixin import TrainingTesterMixin
if is_torch_available():
@@ -304,7 +305,7 @@ def prepare_config_and_inputs_for_common(self):
@require_torch
-class CausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin):
+class CausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, TrainingTesterMixin):
model_tester_class = None
all_model_classes = None
pipeline_model_mapping = None
diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py
index ff2022879422..07691894203f 100644
--- a/tests/generation/test_continuous_batching.py
+++ b/tests/generation/test_continuous_batching.py
@@ -22,6 +22,7 @@
from transformers.generation.continuous_batching.continuous_api import build_attention_mask
from transformers.testing_utils import (
Expectations,
+ require_deterministic_for_xpu,
require_kernels,
require_read_token,
require_torch_accelerator,
@@ -137,6 +138,7 @@ def test_attention_mask(
f"Actual mask:\n{str_mask}"
)
+ @require_deterministic_for_xpu
def _continuous_batching_parity(
self, model_id: str, attn_implementation: str, expected_outputs: dict[str, str]
) -> None:
diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py
index 2570b34cf0eb..637d37984c8f 100644
--- a/tests/models/bart/test_modeling_bart.py
+++ b/tests/models/bart/test_modeling_bart.py
@@ -962,7 +962,7 @@ def test_xsum_summarization_same_as_fairseq(self):
" state."
""
)
- dct = tok.batch_encode_plus(
+ dct = tok(
[PGE_ARTICLE],
max_length=1024,
padding="max_length",
@@ -1188,7 +1188,7 @@ def test_cnn_summarization_same_as_fairseq(self):
" up to four years in prison. Her next court appearance is scheduled for May 18."
)
- dct = tok.batch_encode_plus(
+ dct = tok(
[FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
max_length=1024,
padding="max_length",
diff --git a/tests/models/biogpt/test_modeling_biogpt.py b/tests/models/biogpt/test_modeling_biogpt.py
index 3a1cb8c23c2a..aec7093030a8 100644
--- a/tests/models/biogpt/test_modeling_biogpt.py
+++ b/tests/models/biogpt/test_modeling_biogpt.py
@@ -335,7 +335,9 @@ def test_batch_generation(self):
num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().item()
inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
- output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+ output_padded = model.generate(
+ input_ids=inputs_padded, max_length=model.generation_config.max_length - num_paddings
+ )
batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py
index b9c96df3f537..56ee012aa98c 100644
--- a/tests/models/blt/test_modeling_blt.py
+++ b/tests/models/blt/test_modeling_blt.py
@@ -177,6 +177,10 @@ class BltModelTest(CausalLMModelTest, unittest.TestCase):
# used in `test_torch_compile_for_training`
_torch_compile_train_cls = BltForCausalLM if is_torch_available() else None
+ @unittest.skip("BLT model requires special handling for training overfit test")
+ def test_training_overfit(self):
+ pass
+
@pytest.mark.generate
@parameterized.expand([("greedy", 1), ("beam search", 2)])
@unittest.skip(
diff --git a/tests/models/colpali/test_modeling_colpali.py b/tests/models/colpali/test_modeling_colpali.py
index c1b25f19c348..3f42e1a28bf0 100644
--- a/tests/models/colpali/test_modeling_colpali.py
+++ b/tests/models/colpali/test_modeling_colpali.py
@@ -270,8 +270,8 @@ def test_model_integration_test(self):
ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test")
# Preprocess the examples
- batch_images = self.processor(images=ds["image"]).to(torch_device)
- batch_queries = self.processor(text=ds["query"]).to(torch_device)
+ batch_images = self.processor(images=ds["image"][:]).to(torch_device)
+ batch_queries = self.processor(text=ds["query"][:]).to(torch_device)
# Run inference
with torch.inference_mode():
diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py
index 785e543400a8..0fa47b707bfb 100644
--- a/tests/models/gemma3/test_modeling_gemma3.py
+++ b/tests/models/gemma3/test_modeling_gemma3.py
@@ -576,6 +576,11 @@ def test_model_4b_batch(self):
@require_torch_large_accelerator
def test_model_4b_crops(self):
+ # TODO: fix this test for ROCm
+        # It fails because the static cache does not work correctly on ROCm here, and also in test_model_4b_batch_crops and test_model_4b_multiimage.
+ # The model generates only a few tokens (e.g., "The") and then stops early with EOS tokens
+ # due to NaN logits caused by tensor indexing issues on ROCm devices when using static cache with multimodal inputs.
+
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, dtype=torch.bfloat16).to(torch_device)
@@ -658,7 +663,7 @@ def test_model_4b_batch_crops(self):
**crop_config,
).to(torch_device)
- output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="static")
+ output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_NUM_IMAGES = 9 # 3 * (one for the origin image and two crops of images) = 9
EXPECTED_TEXTS = Expectations(
@@ -677,7 +682,7 @@ def test_model_4b_batch_crops(self):
],
("rocm", (9, 4)) : [
"user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a bright blue sky with some white clouds in the",
- 'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a'
+ 'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first set of images shows a cow on a beach, while the second set shows a street scene'
],
("rocm", (9, 5)) : [
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.',
@@ -718,7 +723,7 @@ def test_model_4b_multiimage(self):
add_generation_prompt=True,
).to(torch_device)
- output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="static")
+ output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = Expectations(
{
@@ -726,7 +731,7 @@ def test_model_4b_multiimage(self):
("cuda", (8, 0)): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt looks like a street scene in a vibrant,"],
("cuda", (8, 6)): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt appears to be a street scene in a city"],
("cuda", (9, 0)): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image!\n\nHere's a description of the scene:\n\n* **Location:**"],
- ("rocm", (9, 4)): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt appears to be a street scene in a vibrant"],
+ ("rocm", (9, 4)): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Main Features:**\n\n* **Chinese Archway:** The most prominent"],
("rocm", (9, 5)): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Main Features:**\n\n* **Chinese Archway:** The most prominent"],
}
) # fmt: skip
@@ -749,7 +754,7 @@ def test_model_1b_text_only(self):
("xpu", 3): ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a river deep,\nWith patterns hidden, secrets sleep.\nA neural net, a watchful eye,\nLearning'],
("cuda", 7): ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a silent stream,\nInto the neural net, a waking dream.\nAlgorithms hum, a coded grace,\n'],
("cuda", 8): ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a silent stream,\nInto the neural net, a waking dream.\nAlgorithms hum, a coded grace,\n'],
- ("rocm", (9, 4)): ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a silent stream,\nInto the neural net, a waking dream.\nAlgorithms hum, a coded grace,\n'],
+ ("rocm", (9, 4)): ['Write a poem about Machine Learning.\n\n---\n\nThe data streams, a boundless flow,\nA silent world, where patterns grow.'],
("rocm", (9, 5)): ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a river deep,\nWith patterns hidden, secrets sleep.\nA neural net, a watchful eye,\nLearning'],
}
) # fmt: skip
@@ -783,7 +788,7 @@ def test_model_4b_flash_attn(self):
("xpu", 3): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant island in the background. It looks like a sunny day'],
("cuda", 7): [],
("cuda", 8): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant island in the background. It looks like a sunny day'],
- ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant island in the background. It looks like a sunny day'],
+ ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. There are some clouds in the blue'],
("rocm", (9, 5)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach with a turquoise ocean and a distant island in the background. It looks like a sunny'],
}
) # fmt: skip
diff --git a/tests/models/internvl/test_modeling_internvl.py b/tests/models/internvl/test_modeling_internvl.py
index d4a31f5951e2..9f51aa3fc325 100644
--- a/tests/models/internvl/test_modeling_internvl.py
+++ b/tests/models/internvl/test_modeling_internvl.py
@@ -430,7 +430,14 @@ def test_qwen2_small_model_integration_batched_generate_multi_image(self):
# Check first output
decoded_output = processor.decode(output[0], skip_special_tokens=True)
# Batching seems to alter the output slightly, but it is also the case in the original implementation. This seems to be expected: https://github.com/huggingface/transformers/issues/23017#issuecomment-1649630232
- expected_output = "user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace." # fmt: skip
+ expected_outputs = Expectations(
+ {
+ ("xpu", 3): 'user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature\'s peace.',
+ ("cuda", 7): 'user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature\'s peace.',
+ ("rocm", (9, 4)): 'user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature\'s embrace.',
+ }
+ ) # fmt: skip
+ expected_output = expected_outputs.get_expectation()
self.assertEqual(
decoded_output,
expected_output,
@@ -443,6 +450,7 @@ def test_qwen2_small_model_integration_batched_generate_multi_image(self):
{
("xpu", 3): "user\n\nWhat are the differences between these two images?\nassistant\nThe images show the Statue of Liberty and the Golden Gate Bridge from different angles. Here are the differences:\n\n1. **Foreground",
("cuda", 7): "user\n\nWhat are the differences between these two images?\nassistant\nThe images show the Statue of Liberty and the Golden Gate Bridge from different angles. Here are the differences:\n\n1. **Foreground",
+ ("rocm", (9, 4)): "user\n\nWhat are the differences between these two images?\nassistant\nThe images show the Statue of Liberty and the Golden Gate Bridge from different angles. Here are the main differences:\n\n1. **",
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
@@ -567,6 +575,7 @@ def test_qwen2_small_model_integration_interleaved_images_videos(self):
{
("xpu", 3): "user\n\n\nWhat are the differences between these two images?\nassistant\nThe images depict two distinct scenes:\n\n1. **Left Image:**\n - The Statue of Liberty is prominently featured on an",
("cuda", 7): 'user\n\n\nWhat are the differences between these two images?\nassistant\nThe images depict two distinct scenes:\n\n1. **Left Image:**\n - The Statue of Liberty is prominently featured on an',
+ ("rocm", (9, 4)): 'user\n\n\nWhat are the differences between these two images?\nassistant\nThe images depict two distinct scenes:\n\n1. **Left Image:**\n - This image features the Statue of Liberty on Liberty',
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
@@ -582,6 +591,7 @@ def test_qwen2_small_model_integration_interleaved_images_videos(self):
{
("xpu", 3): "user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nA forehand shot",
("cuda", 7): 'user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nA forehand shot',
+ ("rocm", (9, 4)): 'user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nA forehand shot',
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
@@ -593,9 +603,14 @@ def test_qwen2_small_model_integration_interleaved_images_videos(self):
# Check third output
decoded_output = processor.decode(output[2], skip_special_tokens=True)
- expected_output = (
- "user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace."
- )
+ expected_outputs = Expectations(
+ {
+ ("xpu", 3): 'user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature\'s peace.',
+ ("cuda", 7): 'user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature\'s peace.',
+ ("rocm", (9, 4)): 'user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature\'s embrace.',
+ }
+ ) # fmt: skip
+ expected_output = expected_outputs.get_expectation()
self.assertEqual(
decoded_output,
expected_output,
@@ -658,7 +673,7 @@ def test_llama_small_model_integration_forward(self):
("xpu", 3): [-9.8828, -0.4954, 1.4561, -10.3438, -10.3438],
("cuda", 7): [-9.8750, -0.4861, 1.4648, -10.3359, -10.3359],
("cuda", 8): [-9.8906, -0.4995, 1.4473, -10.3359, -10.3438],
- ("rocm", (9, 4)): [ -9.8828, -0.5005, 1.4697, -10.3438, -10.3438],
+ ("rocm", (9, 4)): [ -9.8672, -0.4888, 1.4648, -10.3281, -10.3281],
("rocm", (9, 5)): [ -9.8906, -0.4976, 1.4502, -10.3359, -10.3438],
}
) # fmt: skip
@@ -934,7 +949,7 @@ def test_llama_small_model_integration_interleaved_images_videos(self):
("xpu", 3): "user\n\n\nWhat are the difference between these two images?\nassistant\nI apologize for the confusion in my previous response. Upon closer inspection, the differences between the two images are:\n\n1. **",
("cuda", 7): 'user\n\n\nWhat are the difference between these two images?\nassistant\nI apologize for the confusion in my previous response. Upon closer inspection, the differences between the two images are:\n\n1. **',
("cuda", 8): 'user\n\n\nWhat are the difference between these two images?\nassistant\nI apologize for the confusion in my previous response. After re-examining the images, I can see that there are no',
- ("rocm", (9, 4)): 'user\n\n\nWhat are the difference between these two images?\nassistant\nI apologize for the confusion in my previous response. Upon closer inspection, the differences between the two images are:\n\n1. **',
+ ("rocm", (9, 4)): 'user\n\n\nWhat are the difference between these two images?\nassistant\nI apologize for the confusion in my previous response. After re-examining the images, I can see that there are no',
("rocm", (9, 5)): 'user\n\n\nWhat are the difference between these two images?\nassistant\nI apologize for the confusion in my previous response. After re-examining the images, I can see that there are no',
}
) # fmt: skip
diff --git a/tests/models/lasr/__init__.py b/tests/models/lasr/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/lasr/test_modeling_lasr.py b/tests/models/lasr/test_modeling_lasr.py
new file mode 100644
index 000000000000..4b723e715390
--- /dev/null
+++ b/tests/models/lasr/test_modeling_lasr.py
@@ -0,0 +1,390 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch Lasr model."""
+
+import tempfile
+import unittest
+
+from transformers import is_datasets_available, is_torch_available
+from transformers.testing_utils import cleanup, require_torch, slow, torch_device
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+
+
+if is_datasets_available():
+ from datasets import Audio, load_dataset
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ AutoProcessor,
+ LasrCTCConfig,
+ LasrEncoder,
+ LasrEncoderConfig,
+ LasrForCTC,
+ )
+
+
+class LasrEncoderModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=1024,
+ is_training=True,
+ hidden_size=64,
+ num_hidden_layers=2,
+ num_mel_bins=80,
+ num_attention_heads=4,
+ intermediate_size=256,
+ conv_kernel_size=8,
+ subsampling_conv_channels=32,
+ subsampling_conv_kernel_size=5,
+ subsampling_conv_stride=2,
+ ):
+ # testing suite parameters
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.num_mel_bins = num_mel_bins
+ self.is_training = is_training
+
+ # config parameters
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.conv_kernel_size = conv_kernel_size
+ self.subsampling_conv_channels = subsampling_conv_channels
+ self.subsampling_conv_kernel_size = subsampling_conv_kernel_size
+ self.subsampling_conv_stride = subsampling_conv_stride
+
+ # output sequence length after subsampling
+ self.output_seq_length = self._get_output_seq_length(self.seq_length)
+ self.encoder_seq_length = self.output_seq_length
+ self.key_length = self.output_seq_length
+
+ def _get_output_seq_length(self, seq_length):
+ kernel_size = self.subsampling_conv_kernel_size
+ stride = self.subsampling_conv_stride
+ num_layers = 2
+
+ input_length = seq_length
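+        # unpadded 1D conv output length: (length - kernel_size) // stride + 1, applied once per subsampling layer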
+ for _ in range(num_layers):
+ input_length = (input_length - kernel_size) // stride + 1
+
+ return input_length
+
+ def prepare_config_and_inputs(self):
+ input_features = floats_tensor([self.batch_size, self.seq_length, self.num_mel_bins])
+ attention_mask = random_attention_mask([self.batch_size, self.seq_length])
+ config = self.get_config()
+
+ return config, input_features, attention_mask
+
+ def get_config(self):
+ return LasrEncoderConfig(
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ conv_kernel_size=self.conv_kernel_size,
+ subsampling_conv_channels=self.subsampling_conv_channels,
+ subsampling_conv_kernel_size=self.subsampling_conv_kernel_size,
+ subsampling_conv_stride=self.subsampling_conv_stride,
+ num_mel_bins=self.num_mel_bins,
+ )
+
+ def create_and_check_model(self, config, input_features, attention_mask):
+ model = LasrEncoder(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(input_features, attention_mask=attention_mask)
+
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, config.hidden_size)
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config, input_features, attention_mask = self.prepare_config_and_inputs()
+ inputs_dict = {
+ "input_features": input_features,
+ "attention_mask": attention_mask,
+ }
+ return config, inputs_dict
+
+ def check_ctc_loss(self, config, input_values, *args):
+ model = LasrForCTC(config=config)
+ model.to(torch_device)
+
+ # make sure that dropout is disabled
+ model.eval()
+
+ input_values = input_values[:3]
+ attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size)
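+        # CTC requires each label sequence to be no longer than the corresponding encoder output length, hence the -1 margin above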
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_values[i, input_lengths[i] :] = 0.0
+ attention_mask[i, input_lengths[i] :] = 0
+
+ model.config.ctc_loss_reduction = "sum"
+ sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
+
+ model.config.ctc_loss_reduction = "mean"
+ mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
+
+ self.parent.assertTrue(isinstance(sum_loss, float))
+ self.parent.assertTrue(isinstance(mean_loss, float))
+
+
+@require_torch
+class LasrEncoderModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (LasrEncoder,) if is_torch_available() else ()
+
+ test_resize_embeddings = False
+ test_torch_exportable = True
+
+ def setUp(self):
+ self.model_tester = LasrEncoderModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=LasrEncoderConfig, has_text_modality=False)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ @unittest.skip(reason="LasrEncoder does not use inputs_embeds")
+ def test_model_get_set_embeddings(self):
+ pass
+
+
+class LasrForCTCModelTester:
+ def __init__(self, parent, encoder_kwargs=None, is_training=True, vocab_size=128, pad_token_id=0):
+ if encoder_kwargs is None:
+ encoder_kwargs = {}
+
+ self.parent = parent
+ self.encoder_model_tester = LasrEncoderModelTester(parent, **encoder_kwargs)
+ self.is_training = is_training
+
+ self.batch_size = self.encoder_model_tester.batch_size
+ self.output_seq_length = self.encoder_model_tester.output_seq_length
+ self.num_hidden_layers = self.encoder_model_tester.num_hidden_layers
+ self.seq_length = vocab_size
+ self.hidden_size = self.encoder_model_tester.hidden_size
+
+ self.vocab_size = vocab_size
+ self.pad_token_id = pad_token_id
+ self.encoder_seq_length = self.encoder_model_tester.encoder_seq_length
+
+ def prepare_config_and_inputs(self):
+ _, input_features, attention_mask = self.encoder_model_tester.prepare_config_and_inputs()
+ config = self.get_config()
+ return config, input_features, attention_mask
+
+ def get_config(self):
+ return LasrCTCConfig.from_encoder_config(
+ encoder_config=self.encoder_model_tester.get_config(),
+ vocab_size=self.vocab_size,
+ pad_token_id=self.pad_token_id,
+ )
+
+ def create_and_check_model(self, config, input_features, attention_mask):
+ model = LasrForCTC(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(input_features, attention_mask=attention_mask)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.output_seq_length, self.vocab_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config, input_features, attention_mask = self.prepare_config_and_inputs()
+ inputs_dict = {
+ "input_features": input_features,
+ "attention_mask": attention_mask,
+ }
+ return config, inputs_dict
+
+ def test_ctc_loss_inference(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.encoder_model_tester.check_ctc_loss(*config_and_inputs)
+
+
+@require_torch
+class LasrForCTCModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (LasrForCTC,) if is_torch_available() else ()
+ pipeline_model_mapping = (
+ {
+ "feature-extraction": LasrEncoder,
+ "automatic-speech-recognition": LasrForCTC,
+ }
+ if is_torch_available()
+ else {}
+ )
+
+ test_attention_outputs = False
+
+ test_resize_embeddings = False
+ test_torch_exportable = True
+
+ _is_composite = True
+
+ def setUp(self):
+ self.model_tester = LasrForCTCModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=LasrCTCConfig)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ @unittest.skip(reason="LasrEncoder does not use inputs_embeds")
+ def test_model_get_set_embeddings(self):
+ pass
+
+    # Original function assumes a vision+text model, so override it since Lasr is audio+text
+ # Below is modified from `tests/models/granite_speech/test_modeling_granite_speech.py`
+ def test_sdpa_can_dispatch_composite_models(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ if not self._is_composite:
+ self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
+
+ for model_class in self.all_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_sdpa = model_class.from_pretrained(tmpdirname)
+ model_sdpa = model_sdpa.eval().to(torch_device)
+
+ model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
+ model_eager = model_eager.eval().to(torch_device)
+ self.assertTrue(model_eager.config._attn_implementation == "eager")
+
+ for name, submodule in model_eager.named_modules():
+ class_name = submodule.__class__.__name__
+ if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+ raise ValueError("The eager model should not have SDPA attention layers")
+
+
+class LasrForCTCIntegrationTest(unittest.TestCase):
+ _dataset = None
+
+ @classmethod
+    def setUpClass(cls):
+ cls.checkpoint_name = "eustlb/lasr"
+ cls.dtype = torch.bfloat16
+ cls.processor = AutoProcessor.from_pretrained(cls.checkpoint_name)
+
+ def tearDown(self):
+ cleanup(torch_device, gc_collect=True)
+
+ @classmethod
+ def _load_dataset(cls):
+ # Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process.
+ if cls._dataset is None:
+ cls._dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ cls._dataset = cls._dataset.cast_column(
+ "audio", Audio(sampling_rate=cls.processor.feature_extractor.sampling_rate)
+ )
+
+ def _load_datasamples(self, num_samples):
+ self._load_dataset()
+ ds = self._dataset
+ speech_samples = ds.sort("id")[:num_samples]["audio"]
+ return [x["array"] for x in speech_samples]
+
+ @slow
+    @unittest.skip(reason="TODO: enable once the checkpoint is available")
+ def test_model_integration(self):
+ # fmt: off
+ EXPECTED_TOKENS = torch.tensor([
+ [0,0,0,0,0,0,0,0,0,0,0,0,315,0,0,9,0,0,4,0,382,28,0,0,0,0,31,0,0,0,57,57,0,0,7,0,0,14,0,0,0,27,0,0,0,35,0,46,0,0,0,0,16,0,0,7,0,0,192,15,0,15,15,46,0,0,54,100,5,5,0,5,5,71,0,0,0,0,0,0,0,19,19,0,0,0,150,0,142,0,0,0,106,100,100,15,15,0,0,0,18,18,0,0,50,50,121,121,30,30,279,279,0,0,0,63,63,0,0,0,0,188,0,5,5,0,0,0,27,29,0,0,0,0,0,0,0,0,9,0,0,2,2]
+ ])
+ # fmt: on
+
+ # fmt: off
+ EXPECTED_TRANSCRIPTIONS = [
+ "Mr. Kuer is the apstle of the middle classes and we are glad to welcome his gospal."
+ ]
+ # fmt: on
+
+ samples = self._load_datasamples(1)
+ model = LasrForCTC.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device)
+ model.eval()
+ model.to(torch_device)
+
+ # -- apply
+ inputs = self.processor(samples)
+ inputs.to(torch_device, dtype=self.dtype)
+ predicted_ids = model.generate(**inputs)
+ torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKENS)
+ predicted_transcripts = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
+ self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)
+
+ @slow
+    @unittest.skip(reason="TODO: enable once the checkpoint is available")
+ def test_model_integration_batched(self):
+ # fmt: off
+ EXPECTED_TOKENS = torch.tensor([
+ [0,0,0,0,0,0,0,0,0,0,0,0,315,0,0,9,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,57,0,0,7,0,0,0,0,0,0,167,0,0,0,35,0,46,0,0,0,0,16,0,0,7,0,0,192,15,0,15,15,46,0,0,54,100,5,5,0,5,5,71,71,0,0,0,0,0,0,19,19,0,0,0,150,0,142,0,0,0,106,100,100,15,15,0,0,0,0,18,0,0,50,50,121,121,30,30,279,279,0,0,0,63,63,0,0,0,0,188,0,5,5,0,0,27,0,29,0,0,0,0,0,0,0,0,0,0,0,0,9,9,0,156,156,0,229,0,90,0,13,13,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,2],
+ [0,0,0,0,0,0,0,0,0,0,0,0,0,0,117,25,25,0,0,0,57,0,0,0,0,0,315,0,0,9,9,0,0,0,382,0,0,0,0,65,0,34,34,5,0,0,0,179,0,17,17,31,0,0,0,0,0,4,0,343,0,0,0,0,0,24,24,0,0,65,65,0,228,228,0,0,22,0,0,0,0,0,304,0,0,0,0,0,63,63,0,0,0,0,0,0,0,0,113,0,8,0,65,0,0,0,0,0,0,0,0,0,0,0,0,0,9,0,0,156,156,229,229,90,90,90,13,13,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,2],
+ [0,0,0,0,0,0,0,0,0,0,0,0,0,144,0,0,0,450,450,0,5,5,0,0,294,0,0,0,0,0,0,0,0,48,0,0,0,0,0,102,102,0,0,0,149,149,0,0,0,0,0,0,91,0,35,0,0,0,198,0,0,0,0,0,136,136,11,11,5,5,56,56,0,0,0,16,16,0,0,7,0,0,0,286,286,26,26,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,64,0,0,0,0,0,0,398,68,68,35,35,21,21,11,11,5,5,0,0,19,0,0,0,4,4,74,0,86,86,0,0,0,44,49,0,10,10,39,0,0,0,0,305,0,13,21,21,22,0,0,0,0,0,0,0,360,360,0,0,0,294,294,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4,4,5,5,178,178,95,95,0,41,0,0,57,0,0,0,290,290,11,62,17,17,0,0,137,0,0,0,0,0,89,0,99,0,22,22,0,0,0,0,19,0,0,53,0,5,0,0,58,58,5,5,147,147,8,8,5,0,0,4,4,13,13,30,0,0,30,61,61,0,0,0,0,110,0,35,35,0,0,0,58,58,101,0,23,23,41,41,0,0,0,18,0,0,7,7,0,0,192,192,0,82,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,9,0,0,0,0,229,0,90,0,13,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,2],
+ [0,0,0,0,0,0,0,0,0,0,0,0,0,144,0,0,0,299,0,0,0,0,0,391,391,0,76,0,0,0,0,0,104,0,0,8,0,5,0,0,0,0,0,50,222,222,130,130,0,0,0,0,0,0,0,54,0,0,0,39,0,0,12,0,25,84,0,0,0,138,0,0,199,0,252,0,5,5,0,0,0,0,424,0,0,0,0,0,0,57,57,0,0,0,0,0,58,58,29,29,41,41,0,0,0,0,0,0,0,106,33,33,10,10,52,0,0,0,0,0,351,0,0,0,0,0,0,0,0,134,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,19,19,0,0,0,265,265,0,0,0,212,212,0,0,207,0,0,112,0,0,0,0,24,0,0,0,53,0,0,0,0,0,127,0,0,0,0,0,317,0,0,0,0,0,0,0,16,16,0,0,0,0,0,0,0,0,0,4,4,74,0,153,153,0,20,0,0,0,0,89,0,60,60,0,84,84,11,11,0,0,0,0,0,0,0,0,0,0,0,0,0,0,9,0,0,156,0,229,0,90,90,0,13,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,2],
+ ])
+ # fmt: on
+
+ # fmt: off
+ EXPECTED_TRANSCRIPTIONS = [
+ "Mr. is the postle of the middle classes and we are glad to welcome his gospal. [Echo",
+ "nor is Mr.Kter's manner less interesting than his matter. [Echo",
+ "He tells us that at thisvestive season of the year with Christmas and roseb beef looming before us similly is drawn from eating and its results occur most readily to the mind.Echo",
+ "He has grav dots whether cfedric laatetens work is really greek after all and can discover in it but little of rocky ethica. [Echo",
+ "Linel's pictures are sort of upgards and item paintings and maisons exquisite iteddles are as national as a Gingo palm. Mr. Bintckigible] fosters landscapes smile at one much in the same way that Mr. Carcker used to flash his teeth and Mr.J gives his sitter a cheerful slap in the back before he says like a shampoo in a turkeish bath next man"
+ ]
+ # fmt: on
+
+ samples = self._load_datasamples(5)
+ model = LasrForCTC.from_pretrained(
+ self.checkpoint_name,
+ torch_dtype=self.dtype,
+ device_map=torch_device,
+ )
+ model.eval()
+ model.to(torch_device)
+
+ # -- apply
+ inputs = self.processor(samples)
+ inputs.to(torch_device, dtype=self.dtype)
+ predicted_ids = model.generate(**inputs)
+ torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKENS)
+ predicted_transcripts = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
+ self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)
diff --git a/tests/models/mistral3/test_modeling_mistral3.py b/tests/models/mistral3/test_modeling_mistral3.py
index 3a7f51642a7b..15109b2aec8c 100644
--- a/tests/models/mistral3/test_modeling_mistral3.py
+++ b/tests/models/mistral3/test_modeling_mistral3.py
@@ -355,7 +355,8 @@ def test_mistral3_integration_batched_generate(self):
expected_outputs = Expectations(
{
("xpu", 3): "Calm lake's mirror gleams,\nWhispering pines stand in silence,\nPath to peace begins.",
- ("cuda", 8): "Wooden path to calm,\nReflections whisper secrets,\nNature's peace unfolds.",
+ ("cuda", (8, 0)): "Wooden path to calm,\nReflections whisper secrets,\nNature's peace unfolds.",
+ ("cuda", (8, 6)): "Calm waters reflect\nWooden path to distant shore\nSilence in the woods",
("rocm", (9, 5)): "Calm waters reflect\nWooden path to distant shore\nSilence in the scene"
}
) # fmt: skip
@@ -432,7 +433,8 @@ def test_mistral3_integration_batched_generate_multi_image(self):
decoded_output = processor.decode(gen_tokens[0], skip_special_tokens=True)
expected_outputs = Expectations(
{
- ("cuda", 8): 'Calm waters reflect\nWooden path to distant shore\nSilence in the scene',
+ ("cuda", 8): "Calm waters reflect\nWooden path to distant shore\nPeace in nature's hold",
+ ("rocm", (9, 4)): "Calm waters reflect\nWooden path to distant shore\nSilence in the pines"
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
@@ -448,6 +450,7 @@ def test_mistral3_integration_batched_generate_multi_image(self):
{
("xpu", 3): "Certainly! The images depict two iconic landmarks:\n\n1. The first image shows the Statue of Liberty in New York City.",
("cuda", 8): 'Certainly! The images depict two famous landmarks in the United States:\n\n1. The first image shows the Statue of Liberty,',
+ ("rocm", (9, 4)): 'Certainly! The images depict two famous landmarks in the United States:\n\n1. The first image shows the Statue of Liberty,',
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py
index 6ecef7519f8f..d12fbd0812f9 100644
--- a/tests/models/opt/test_modeling_opt.py
+++ b/tests/models/opt/test_modeling_opt.py
@@ -474,22 +474,22 @@ def test_batch_generation(self):
outputs = model.generate(
input_ids=input_ids,
attention_mask=inputs["attention_mask"].to(torch_device),
+ max_new_tokens=10,
)
inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
- output_non_padded = model.generate(input_ids=inputs_non_padded)
+ output_non_padded = model.generate(input_ids=inputs_non_padded, max_new_tokens=10)
- num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().item()
inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
- output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+ output_padded = model.generate(input_ids=inputs_padded, max_new_tokens=10)
batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
expected_output_sentence = [
- "Hello, my dog is a little bit of a dork.\nI'm a little bit",
- "Today, I was in the middle of a conversation with a friend about the",
+ "Hello, my dog is a little bit of a dork.\nI'm a",
+ "Today, I was in the middle of a conversation with a friend",
]
self.assertListEqual(expected_output_sentence, batch_out_sentence)
self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])
diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py
index 1e0e2335067d..0f4e16964265 100644
--- a/tests/peft_integration/test_peft_integration.py
+++ b/tests/peft_integration/test_peft_integration.py
@@ -889,6 +889,60 @@ def test_peft_pipeline_no_warning(self):
# Generate text to verify pipeline works
_ = lora_generator(text, max_new_tokens=20)
+ def test_non_lora_load_adapter(self):
+ """
+ Check that loading a non-LoRA adapter works. Using LoKr as an example, not testing all possible PEFT methods.
+ """
+ from peft import LoKrConfig, get_peft_model
+
+ inputs = torch.randint(0, 100, (1, 10)).to(torch_device)
+ atol, rtol = 1e-4, 1e-4
+
+ for model_id in self.transformers_test_model_ids:
+ for transformers_class in self.transformers_test_model_classes:
+ model = transformers_class.from_pretrained(model_id).to(torch_device)
+ with torch.inference_mode():
+ output_base = model(inputs).logits
+
+ peft_config = LoKrConfig(init_weights=False)
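+                # init_weights=False makes the LoKr adapter non-identity, so outputs are expected to change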
+ peft_model = get_peft_model(model, peft_config)
+ with torch.inference_mode():
+ output_peft = peft_model(inputs).logits
+
+ # sanity check: should be different
+ assert not torch.allclose(output_base, output_peft, atol=atol, rtol=rtol)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ peft_model.save_pretrained(tmpdirname)
+ del model, peft_model
+
+ model = transformers_class.from_pretrained(tmpdirname).to(torch_device)
+ with torch.inference_mode():
+ output_transformers = model(inputs).logits
+ self.assertTrue(torch.allclose(output_peft, output_transformers, atol=atol, rtol=rtol))
+
+ def test_non_lora_add_adapter(self):
+ """
+ Check that adding a non-LoRA adapter works. Using LoKr as an example, not testing all possible PEFT methods.
+ """
+ from peft import LoKrConfig
+
+ inputs = torch.randint(0, 100, (1, 10)).to(torch_device)
+ atol, rtol = 1e-4, 1e-4
+
+ for model_id in self.transformers_test_model_ids:
+ for transformers_class in self.transformers_test_model_classes:
+ model = transformers_class.from_pretrained(model_id).to(torch_device)
+ with torch.inference_mode():
+ output_base = model(inputs).logits
+
+ peft_config = LoKrConfig(init_weights=False)
+ model.add_adapter(peft_config)
+ with torch.inference_mode():
+ output_peft = model(inputs).logits
+ # should be different
+ assert not torch.allclose(output_base, output_peft, atol=atol, rtol=rtol)
+
@require_peft
@require_torch
diff --git a/tests/quantization/compressed_tensors_integration/test_compressed_models.py b/tests/quantization/compressed_tensors_integration/test_compressed_models.py
index 47ab72e1e071..15d29e47f4a0 100644
--- a/tests/quantization/compressed_tensors_integration/test_compressed_models.py
+++ b/tests/quantization/compressed_tensors_integration/test_compressed_models.py
@@ -80,7 +80,9 @@ def _has_nested_attr(obj, attr_path):
if comp_decomp_obj is not None and hasattr(submodule, "weight"):
if "sparse-only" in uncompressed_model:
self.assertTrue(
- torch.equal(submodule.weight, comp_decomp_obj.weight),
+ torch.equal(
+ submodule.weight.to(torch_device), comp_decomp_obj.weight.to(torch_device)
+ ),
f"Weight mismatch for module '{name}' in sparse-only model.",
)
else:
diff --git a/tests/quantization/compressed_tensors_integration/test_compressed_tensors.py b/tests/quantization/compressed_tensors_integration/test_compressed_tensors.py
index f4e502dc73fa..461ebe5aff47 100644
--- a/tests/quantization/compressed_tensors_integration/test_compressed_tensors.py
+++ b/tests/quantization/compressed_tensors_integration/test_compressed_tensors.py
@@ -2,7 +2,13 @@
import unittest
from transformers import AutoModelForCausalLM, AutoTokenizer, CompressedTensorsConfig
-from transformers.testing_utils import backend_empty_cache, require_compressed_tensors, require_torch, torch_device
+from transformers.testing_utils import (
+ backend_empty_cache,
+ require_compressed_tensors,
+ require_deterministic_for_xpu,
+ require_torch,
+ torch_device,
+)
from transformers.utils import is_torch_available
@@ -47,22 +53,33 @@ def test_config_to_from_dict(self):
self.assertIsInstance(config_from_dict.sparsity_config, SparsityCompressionConfig)
def test_tinyllama_w8a8(self):
- expected_out = " Paris is the capital of which country?\n\n**A) 10** Paris is the capital of which country?\n\n**B) 11** Paris is the capital of which country?\n\n**C) 1"
+ expected_out = [
+ " Paris is the capital of which country?\n\n**A) 10** Paris is the capital of which country?\n\n**B) 11** Paris is the capital of which country?\n\n**C) 1",
+ " Paris is the capital of which country?\n\n** 10.** Which country is the capital of which country?\n\n** 11.** Which country is the capital of which country?\n\n** 12.", # XPU
+ ]
self._test_quantized_model(self.tinyllama_w8a8, expected_out)
def test_tinyllama_w4a16(self):
- expected_out = " Paris is the capital of which country?\nAnswer: Paris is the capital of France.\nQuestion: Which country is the capital of which city?\nAnswer: The capital of the city of New York is New York.\nQuestion: Which"
+ expected_out = [
+ " Paris is the capital of which country?\nAnswer: Paris is the capital of France.\nQuestion: Which country is the capital of which city?\nAnswer: The capital of the city of New York is New York.\nQuestion: Which"
+ ]
self._test_quantized_model(self.tinyllama_w4a16, expected_out)
def test_tinyllama_w8a16(self):
- expected_out = " Paris is the capital of which country?\nA. France\nB. Germany\nC. Spain\nD. Italy\nE. Switzerland\nQ10. Which of the following is not a country in the European Union?\nA."
+ expected_out = [
+ " Paris is the capital of which country?\nA. France\nB. Germany\nC. Spain\nD. Italy\nE. Switzerland\nQ10. Which of the following is not a country in the European Union?\nA."
+ ]
self._test_quantized_model(self.tinyllama_w8a16, expected_out)
def test_llama_8b_fp8(self):
- expected_out = "<|begin_of_text|>Paris is the capital of which country? France\nWhat is the name of the famous art museum in Paris? The Louvre\nWhat is the name of the famous bridge in Paris? Pont des Arts\nWhat is the name of the famous opera? "
+ expected_out = [
+ "<|begin_of_text|>Paris is the capital of which country? France\nWhat is the name of the famous art museum in Paris? The Louvre\nWhat is the name of the famous bridge in Paris? Pont des Arts\nWhat is the name of the famous opera? ",
+ "<|begin_of_text|>Paris is the capital of which country? France\nWhat is the name of the famous art museum in Paris? The Louvre\nWhat is the name of the famous bridge in Paris? Pont des Arts\nWhat is the name of the famous opera", # XPU
+ ]
self._test_quantized_model(self.llama3_8b_fp8, expected_out)
- def _test_quantized_model(self, model_name: str, expected_output: str):
+ @require_deterministic_for_xpu
+ def _test_quantized_model(self, model_name: str, expected_output: list):
"""Carry out generation"""
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -84,4 +101,4 @@ def _test_quantized_model(self, model_name: str, expected_output: str):
outputs = tokenizer.batch_decode(generated_ids)
self.assertIsNotNone(outputs)
- self.assertEqual(outputs[0], expected_output)
+ self.assertIn(outputs[0], expected_output)
diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py
index 417abb12bc50..694cfd2715b4 100644
--- a/tests/quantization/torchao_integration/test_torchao.py
+++ b/tests/quantization/torchao_integration/test_torchao.py
@@ -725,6 +725,7 @@ def check_serialization_expected_output(self, device, expected_output, safe_seri
dtype = torch.bfloat16 if isinstance(self.quant_scheme, Int4WeightOnlyConfig) else "auto"
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname, safe_serialization=safe_serialization)
+
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
tmpdirname, dtype=dtype, device_map=device, torch_dtype=dtype, use_safetensors=safe_serialization
)
@@ -738,7 +739,7 @@ def test_serialization_expected_output(self):
@require_torchao
-@require_torchao_version_greater_or_equal("0.14.0")
+@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoSafeSerializationTest(TorchAoSerializationTest):
# called only once for all test in this class
@classmethod
@@ -763,6 +764,16 @@ def tearDown(self):
"What are we having for dinner?\n\nJess: (smiling) I",
),
(torchao.quantization.Float8WeightOnlyConfig(), "What are we having for dinner?\n\nJessica: (smiling)"),
+ (Int4WeightOnlyConfig(), "What are we having for dinner?"),
+ (
+ Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"),
+ "What are we having for dinner?\nRed, white, and green beans,",
+ ),
+ (
+ torchao.quantization.Int8DynamicActivationIntxWeightConfig(),
+ "What are we having for dinner?\n\nJessica: (smiling)",
+ ),
+ (torchao.quantization.IntxWeightOnlyConfig(), "What are we having for dinner?\n\nJessica: (smiling)"),
]
if is_torchao_available()
else []
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index 673f8def3159..a521f5119697 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -2302,7 +2302,9 @@ def test_batch_encode_plus_batch_sequence_length(self):
encoded_sequences = [tokenizer(sequence) for sequence in sequences]
encoded_sequences_batch = tokenizer(sequences, padding=False)
- self.assertListEqual(encoded_sequences, self.convert_batch_to_list_format(encoded_sequences_batch))
+ self.assertListEqual(
+ encoded_sequences, TokenizerTesterMixin.convert_batch_to_list_format(encoded_sequences_batch)
+ )
maximum_length = len(max([encoded_sequence["input_ids"] for encoded_sequence in encoded_sequences], key=len))
@@ -2316,7 +2318,7 @@ def test_batch_encode_plus_batch_sequence_length(self):
encoded_sequences_batch_padded = tokenizer(sequences, padding=True)
self.assertListEqual(
encoded_sequences_padded,
- self.convert_batch_to_list_format(encoded_sequences_batch_padded),
+ TokenizerTesterMixin.convert_batch_to_list_format(encoded_sequences_batch_padded),
)
# check 'longest' is unsensitive to a max length
@@ -2357,7 +2359,9 @@ def test_batch_encode_plus_padding(self):
tokenizer(sequence, max_length=max_length, padding="max_length") for sequence in sequences
]
encoded_sequences_batch = tokenizer(sequences, max_length=max_length, padding="max_length")
- self.assertListEqual(encoded_sequences, self.convert_batch_to_list_format(encoded_sequences_batch))
+ self.assertListEqual(
+ encoded_sequences, TokenizerTesterMixin.convert_batch_to_list_format(encoded_sequences_batch)
+ )
# Left padding tests
tokenizer = self.get_tokenizer(do_lower_case=False)
@@ -2377,7 +2381,9 @@ def test_batch_encode_plus_padding(self):
tokenizer(sequence, max_length=max_length, padding="max_length") for sequence in sequences
]
encoded_sequences_batch = tokenizer(sequences, max_length=max_length, padding="max_length")
- self.assertListEqual(encoded_sequences, self.convert_batch_to_list_format(encoded_sequences_batch))
+ self.assertListEqual(
+ encoded_sequences, TokenizerTesterMixin.convert_batch_to_list_format(encoded_sequences_batch)
+ )
def test_pretokenized_inputs(self):
# Test when inputs are pretokenized
diff --git a/tests/test_training_mixin.py b/tests/test_training_mixin.py
new file mode 100644
index 000000000000..f0b894ca32cb
--- /dev/null
+++ b/tests/test_training_mixin.py
@@ -0,0 +1,413 @@
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Training overfit tester mixin for model tests."""
+
+import logging
+import time
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+
+from transformers import set_seed
+from transformers.testing_utils import Colors, build_cpu_memory_monitor, init_test_logger, is_training_test
+
+
+logger = logging.getLogger("transformers.training_test")
+
+
+class TrainingTesterMixin(ABC):
+ """
+ Mixin for training overfit tests. Add to model test classes alongside ModelTesterMixin.
+
+ The model_tester (e.g., CausalLMModelTester) already provides:
+ - get_config() -> tiny model config
+ - prepare_config_and_inputs_for_common() -> config + input dict
+ - causal_lm_class, base_model_class, etc.
+
+ This mixin adds training-specific tests using that infrastructure.
+ """
+
+ # ============================================================
+ # Training hyperparameters
+ # ============================================================
+ training_overfit_steps: int = 300
+ training_overfit_batch_size: int = 2
+ training_overfit_learning_rate: float = 1e-3
+ training_overfit_seq_length: int = 64
+ training_overfit_log_freq: int = 10
+
+    # Loss reduction and grad norm reduction thresholds for passing the test (i.e. a 90% reduction)
+ training_loss_reduction_threshold: float = 0.9
+ training_grad_norm_reduction_threshold: float = 0.9
+
+ @property
+ @abstractmethod
+ def model_tester(self):
+ """The model tester instance (e.g., CausalLMModelTester)."""
+ ...
+
+ # ============================================================
+ # Modality detection
+ # ============================================================
+ def _get_model_modality(self) -> str:
+ """Detect the modality of the model based on its input signature."""
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ if "input_ids" in inputs_dict:
+ return "text"
+ elif "pixel_values" in inputs_dict:
+ return "image"
+ elif "input_features" in inputs_dict or "input_values" in inputs_dict:
+ return "audio"
+ else:
+            raise ValueError(f"Could not infer modality from inputs: {list(inputs_dict.keys())}")
+
+ # ============================================================
+ # Training data creation for each modality
+ # ============================================================
+ def _create_text_training_batch(
+ self,
+ batch_size: int,
+ seq_length: int,
+ vocab_size: int,
+ ) -> dict[str, torch.Tensor]:
+ """Create a simple text batch without needing a tokenizer."""
+ # Create a deterministic sequence (not random, so model can learn it)
+ pattern = list(range(1, min(20, vocab_size))) # tokens 1-19
+ num_repeats = (seq_length // len(pattern)) + 1
+ tokens = (pattern * num_repeats)[:seq_length]
+ input_ids = torch.tensor([tokens] * batch_size, dtype=torch.long)
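+        # labels mirror input_ids; causal LM heads shift them internally when computing the loss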
+ return {"input_ids": input_ids, "labels": input_ids.clone()}
+
+ def _create_image_training_batch(
+ self,
+ batch_size: int,
+ num_channels: int,
+ height: int,
+ width: int,
+ ) -> dict[str, torch.Tensor]:
+ """Create fixed batch for image models using a deterministic pattern."""
+ pass
+
+ def _create_audio_training_batch(
+ self,
+ batch_size: int,
+ audio_length: int,
+ feature_size: Optional[int] = None,
+ ) -> dict[str, torch.Tensor]:
+ """Create fixed batch for audio models using a deterministic waveform."""
+ pass
+
+ def _decode_text_tokens(self, tokens: list[int], max_display: int = 40) -> str:
+ """Decode tokens to readable string (maps token IDs to letters: 1->a, 2->b, etc.)."""
+ decoded = "".join(chr(ord("a") + (t - 1) % 26) for t in tokens)
+ if len(decoded) > max_display:
+ return f"'{decoded[:max_display]}...'"
+ return f"'{decoded}'"
+
+ def _get_trainable_model_class(self):
+ """Get the model class to use for training (prefers *ForCausalLM, *ForSequenceClassification, etc.)."""
+ # Prefer model classes with a head (for computing loss)
+ if hasattr(self.model_tester, "causal_lm_class") and self.model_tester.causal_lm_class is not None:
+ return self.model_tester.causal_lm_class
+ if (
+ hasattr(self.model_tester, "sequence_classification_class")
+ and self.model_tester.sequence_classification_class is not None
+ ):
+ return self.model_tester.sequence_classification_class
+ # Fall back to first model class
+ return self.all_model_classes[0]
+
+ @is_training_test
+ def test_training_overfit(self):
+ """Test that a tiny model can overfit on a fixed batch."""
+ # Initialize logging and memory monitoring
+ init_test_logger()
+ memory_monitor = build_cpu_memory_monitor(logger)
+
+ logger.info("=" * 70)
+ logger.info(f"Starting test: {self._testMethodName}")
+ logger.info("=" * 70)
+
+ # Skip if model doesn't support training
+ if not getattr(self.model_tester, "is_training", True):
+ logger.info(f"{Colors.YELLOW}Skipping: Model tester not configured for training tests{Colors.RESET}")
+ self.skipTest("Model tester not configured for training tests")
+
+ # Configuration
+ logger.info(f"{Colors.BOLD}Job Configuration:{Colors.RESET}")
+ logger.info(f" {Colors.CYAN}total_steps:{Colors.RESET} {self.training_overfit_steps}")
+ logger.info(f" {Colors.CYAN}batch_size:{Colors.RESET} {self.training_overfit_batch_size}")
+ logger.info(f" {Colors.CYAN}learning_rate:{Colors.RESET} {self.training_overfit_learning_rate}")
+ logger.info(f" {Colors.CYAN}seq_length:{Colors.RESET} {self.training_overfit_seq_length}")
+ logger.info(f" {Colors.CYAN}log_freq:{Colors.RESET} {self.training_overfit_log_freq}")
+ logger.info(f" {Colors.CYAN}device:{Colors.RESET} cpu")
+
+ set_seed(42)
+
+ logger.info("-" * 70)
+ logger.info(f"{Colors.BOLD}Building model{Colors.RESET}")
+ load_start = time.perf_counter()
+
+ # Get tiny config from existing infrastructure
+ config = self.model_tester.get_config()
+
+ model_class = self._get_trainable_model_class()
+ model = model_class(config)
+ model.train()
+
+ load_time = time.perf_counter() - load_start
+ logger.info(f"Model loaded in {Colors.GREEN}{load_time:.3f}s{Colors.RESET}")
+
+ # Log model architecture
+        # TODO(3outeille): check whether there are other parameters worth logging
+ logger.info(f"{Colors.BOLD}Model Architecture:{Colors.RESET}")
+ logger.info(f" {Colors.CYAN}model_class:{Colors.RESET} {model_class.__name__}")
+ if hasattr(config, "hidden_size"):
+ logger.info(f" {Colors.CYAN}hidden_size:{Colors.RESET} {config.hidden_size}")
+ if hasattr(config, "num_hidden_layers"):
+ logger.info(f" {Colors.CYAN}num_hidden_layers:{Colors.RESET} {config.num_hidden_layers}")
+ if hasattr(config, "num_attention_heads"):
+ logger.info(f" {Colors.CYAN}num_attention_heads:{Colors.RESET} {config.num_attention_heads}")
+ if hasattr(config, "num_key_value_heads"):
+ logger.info(f" {Colors.CYAN}num_key_value_heads:{Colors.RESET} {config.num_key_value_heads}")
+ if hasattr(config, "intermediate_size"):
+ logger.info(f" {Colors.CYAN}intermediate_size:{Colors.RESET} {config.intermediate_size}")
+ if hasattr(config, "vocab_size"):
+ logger.info(f" {Colors.CYAN}vocab_size:{Colors.RESET} {config.vocab_size}")
+ if hasattr(config, "num_experts"):
+ logger.info(f" {Colors.CYAN}num_experts:{Colors.RESET} {config.num_experts}")
+ if hasattr(config, "num_experts_per_tok"):
+ logger.info(f" {Colors.CYAN}num_experts_per_tok:{Colors.RESET} {config.num_experts_per_tok}")
+
+ # Count parameters
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ logger.info(
+ f"{Colors.CYAN}Model size:{Colors.RESET} {Colors.BRIGHT_GREEN}{total_params:,}{Colors.RESET} total parameters"
+ )
+ logger.info(
+ f"{Colors.CYAN}Trainable parameters:{Colors.RESET} {Colors.BRIGHT_GREEN}{trainable_params:,}{Colors.RESET}"
+ )
+
+ # Memory after model load
+ mem_stats = memory_monitor.get_stats()
+ logger.info(
+ f"{Colors.MAGENTA}Memory after model load:{Colors.RESET} {mem_stats.rss_gib:.2f} GiB ({mem_stats.rss_pct:.1f}%)"
+ )
+
+ logger.info("-" * 70)
+ logger.info(f"{Colors.BOLD}Creating fixed batch{Colors.RESET}")
+
+ modality = self._get_model_modality()
+ logger.info(f"{Colors.CYAN}Detected modality:{Colors.RESET} {modality}")
+ _, sample_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+
+ if modality == "text":
+            # For text models, no tokenizer is needed: build a synthetic repeating token sequence
+ batch = self._create_text_training_batch(
+ batch_size=self.training_overfit_batch_size,
+ seq_length=self.training_overfit_seq_length,
+ vocab_size=config.vocab_size,
+ )
+ logger.info(f"{Colors.CYAN}Training pattern:{Colors.RESET} Repeating token sequence (1-19)")
+ else:
+ raise ValueError(f"Modality {modality} not supported yet for training overfit")
+
+ tokens_per_batch = self.training_overfit_batch_size * self.training_overfit_seq_length
+ logger.info(f" {Colors.CYAN}batch_size:{Colors.RESET} {self.training_overfit_batch_size}")
+ logger.info(f" {Colors.CYAN}seq_length:{Colors.RESET} {self.training_overfit_seq_length}")
+ logger.info(f" {Colors.CYAN}tokens_per_batch:{Colors.RESET} {tokens_per_batch:,}")
+ logger.info(f"{Colors.DIM}Using same fixed batch every step (deterministic overfitting){Colors.RESET}")
+
+ logger.info("-" * 70)
+ logger.info(f"{Colors.BOLD}Building optimizer{Colors.RESET}")
+
+ optimizer = torch.optim.Adam(
+ model.parameters(), lr=self.training_overfit_learning_rate, weight_decay=0.0, betas=(0.9, 0.999)
+ )
+ logger.info(f"{Colors.CYAN}Optimizer:{Colors.RESET} Adam")
+ logger.info(f" {Colors.CYAN}learning_rate:{Colors.RESET} {self.training_overfit_learning_rate}")
+ logger.info(f" {Colors.CYAN}weight_decay:{Colors.RESET} 0.0")
+ logger.info(f" {Colors.CYAN}betas:{Colors.RESET} (0.9, 0.999)")
+
+ # Training Loop
+ logger.info("-" * 70)
+ logger.info("Training starts at step 1")
+
+ initial_loss = None
+ final_loss = None
+ initial_grad_norm = None
+ final_grad_norm = None
+ training_start = time.perf_counter()
+ memory_monitor.reset_peak_stats()
+
+ for step in range(1, self.training_overfit_steps + 1):
+ step_start = time.perf_counter()
+
+ optimizer.zero_grad()
+ outputs = model(**batch)
+ loss = outputs.loss
+
+ if initial_loss is None:
+ initial_loss = loss.item()
+ final_loss = loss.item()
+
+ loss.backward()
+
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
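+            # clip_grad_norm_ returns the total gradient norm measured before clipping, which is what we track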
+
+ if initial_grad_norm is None:
+ initial_grad_norm = grad_norm.item()
+ final_grad_norm = grad_norm.item()
+
+ optimizer.step()
+
+ step_time = time.perf_counter() - step_start
+
+ # Log at frequency
+ if step == 1 or step % self.training_overfit_log_freq == 0 or step == self.training_overfit_steps:
+ tokens_per_sec = tokens_per_batch / step_time
+ mem_stats = memory_monitor.get_stats()
+ logger.info(
+ f"{Colors.CYAN}step:{Colors.RESET} {step} "
+ f"{Colors.GREEN}loss:{Colors.RESET} {loss.item():7.4f} "
+ f"{Colors.YELLOW}grad_norm:{Colors.RESET} {grad_norm.item():6.4f} "
+ f"{Colors.MAGENTA}memory:{Colors.RESET} {mem_stats.rss_gib:.2f}GiB({mem_stats.rss_pct:.1f}%) "
+ f"{Colors.BLUE}tok/s:{Colors.RESET} {tokens_per_sec:,.0f} "
+ f"{Colors.DIM}step_time:{Colors.RESET} {step_time:.3f}s"
+ )
+
+ training_time = time.perf_counter() - training_start
+
+ # Training Summary
+ total_tokens = self.training_overfit_steps * tokens_per_batch
+ logger.info("-" * 70)
+ logger.info(f"{Colors.BOLD}Training completed{Colors.RESET}")
+ logger.info(f"Total training time: {training_time:.2f}s")
+ logger.info(f"Total steps: {self.training_overfit_steps}")
+ logger.info(f"Total tokens seen: {total_tokens:,}")
+ logger.info(f"Average tokens/sec: {total_tokens / training_time:,.0f}")
+
+ # Memory summary
+ mem_stats = memory_monitor.get_stats()
+ logger.info(f"{Colors.BOLD}Memory usage:{Colors.RESET}")
+ logger.info(
+ f" {Colors.CYAN}current_rss:{Colors.RESET} {mem_stats.rss_gib:.2f} GiB ({mem_stats.rss_pct:.1f}%)"
+ )
+ logger.info(
+ f" {Colors.CYAN}peak_rss:{Colors.RESET} {mem_stats.peak_rss_gib:.2f} GiB ({mem_stats.peak_rss_pct:.1f}%)"
+ )
+ logger.info(
+ f" {Colors.CYAN}available:{Colors.RESET} {mem_stats.available_gib:.2f} GiB / {mem_stats.total_gib:.2f} GiB"
+ )
+
+ # Loss analysis
+ loss_reduction = (initial_loss - final_loss) / initial_loss * 100
+ logger.info(f"{Colors.BOLD}Loss metrics:{Colors.RESET}")
+ logger.info(f" {Colors.CYAN}initial_loss:{Colors.RESET} {initial_loss:.4f}")
+ logger.info(f" {Colors.CYAN}final_loss:{Colors.RESET} {final_loss:.4f}")
+ logger.info(f" {Colors.CYAN}loss_reduction:{Colors.RESET} {loss_reduction:.1f}%")
+
+ # Grad norm analysis
+ grad_norm_reduction = (initial_grad_norm - final_grad_norm) / initial_grad_norm * 100
+ logger.info(f"{Colors.BOLD}Grad norm metrics:{Colors.RESET}")
+ logger.info(f" {Colors.CYAN}initial_grad_norm:{Colors.RESET} {initial_grad_norm:.4f}")
+ logger.info(f" {Colors.CYAN}final_grad_norm:{Colors.RESET} {final_grad_norm:.4f}")
+ logger.info(f" {Colors.CYAN}grad_norm_reduction:{Colors.RESET} {grad_norm_reduction:.1f}%")
+
+ # Generation Test (only for text/causal LM models)
+        # TODO(3outeille): handle generation for audio and image models
+ generation_matches = None
+ if modality == "text" and hasattr(model, "generate"):
+ logger.info("-" * 70)
+ logger.info(f"{Colors.BOLD}Testing generation{Colors.RESET}")
+
+ model.eval()
+
+ # Get the expected token sequence (same pattern used in training)
+ expected_tokens = batch["input_ids"][0].tolist()
+
+ # Use first token as prompt
+ prompt_ids = torch.tensor([[expected_tokens[0]]], dtype=torch.long)
+ num_tokens_to_generate = len(expected_tokens) - 1
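+            # with greedy decoding, an overfitted model should reproduce the memorized pattern from this single-token prompt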
+
+ logger.info(f"Prompt: {self._decode_text_tokens([expected_tokens[0]])}")
+
+ with torch.no_grad():
+ generated_ids = model.generate(
+ prompt_ids,
+ max_new_tokens=num_tokens_to_generate,
+ do_sample=False,
+ pad_token_id=config.pad_token_id if hasattr(config, "pad_token_id") else 0,
+ eos_token_id=0,
+ )
+
+ generated_tokens = generated_ids[0].tolist()
+
+ # Compare generated tokens with expected tokens
+ generation_matches = generated_tokens == expected_tokens
+
+ # TODO(3outeille): handle audio and image generation
+ if generation_matches:
+ logger.info(f"Expected: {Colors.GREEN}{self._decode_text_tokens(expected_tokens)}{Colors.RESET}")
+ logger.info(f"Generated: {Colors.GREEN}{self._decode_text_tokens(generated_tokens)}{Colors.RESET}")
+ logger.info(f"{Colors.GREEN}✓ Generation matches training sequence!{Colors.RESET}")
+ else:
+ logger.info(f"Expected: {Colors.GREEN}{self._decode_text_tokens(expected_tokens)}{Colors.RESET}")
+ logger.info(f"Generated: {Colors.RED}{self._decode_text_tokens(generated_tokens)}{Colors.RESET}")
+ # Count matching tokens
+ matches = sum(1 for g, e in zip(generated_tokens, expected_tokens) if g == e)
+ logger.info(
+ f"{Colors.YELLOW}✗ Generation mismatch: {matches}/{len(expected_tokens)} tokens match{Colors.RESET}"
+ )
+
+ # Assertions
+ logger.info("-" * 70)
+ logger.info(f"{Colors.BOLD}Running assertions{Colors.RESET}")
+
+ # Assert loss decreased significantly
+ loss_reduction_ratio = (initial_loss - final_loss) / initial_loss
+ self.assertGreater(
+ loss_reduction_ratio,
+ self.training_loss_reduction_threshold,
+ f"Expected loss to decrease by at least {self.training_loss_reduction_threshold * 100:.0f}%, "
+ f"got {loss_reduction:.1f}%",
+ )
+ logger.info(
+ f"{Colors.GREEN}✓ Loss decreased by more than {self.training_loss_reduction_threshold * 100:.0f}%{Colors.RESET}"
+ )
+
+ # Assert grad_norm decreased significantly
+ grad_norm_reduction_ratio = (initial_grad_norm - final_grad_norm) / initial_grad_norm
+ self.assertGreater(
+ grad_norm_reduction_ratio,
+ self.training_grad_norm_reduction_threshold,
+ f"Expected grad_norm to decrease by at least {self.training_grad_norm_reduction_threshold * 100:.0f}%, "
+ f"got {grad_norm_reduction:.1f}%",
+ )
+ logger.info(
+ f"{Colors.GREEN}✓ Grad norm decreased by more than {self.training_grad_norm_reduction_threshold * 100:.0f}%{Colors.RESET}"
+ )
+
+ # Assert generation matches (if applicable)
+ if generation_matches is not None:
+ self.assertTrue(generation_matches, "Expected model to generate the training sequence after overfitting")
+ logger.info(f"{Colors.GREEN}✓ Generated sequence matches training sequence{Colors.RESET}")
+
+ logger.info("=" * 70)
+ logger.info(f"Finished test: {self._testMethodName}")
+ logger.info("=" * 70)
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index b56728ce112b..0cf8ae251b4c 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -315,16 +315,16 @@ class TestOffline(unittest.TestCase):
def test_offline(self):
with tempfile.TemporaryDirectory() as tmpdir:
# First offline load should fail
- with patch("transformers.utils.hub.is_offline_mode", return_value=True):
+ with patch("huggingface_hub.constants.HF_HUB_OFFLINE", True):
with pytest.raises(OSError):
AutoModelForImageClassification.from_pretrained(TINY_IMAGE_CLASSIF, cache_dir=tmpdir)
# Enable online mode for download
- with patch("transformers.utils.hub.is_offline_mode", return_value=False):
+ with patch("huggingface_hub.constants.HF_HUB_OFFLINE", False):
snapshot_download(TINY_IMAGE_CLASSIF, cache_dir=tmpdir)
# Load again in offline mode - should work now
- with patch("transformers.utils.hub.is_offline_mode", return_value=True):
+ with patch("huggingface_hub.constants.HF_HUB_OFFLINE", True):
AutoModelForImageClassification.from_pretrained(TINY_IMAGE_CLASSIF, cache_dir=tmpdir)
def test_local_files_only(self):
@@ -2228,6 +2228,41 @@ def test_device_map_works_with_unexpected_keys_sharded(self):
# Unexpected keys (mtp) should be removed from the state dict, therefore this should not error out.
BaseModelWithUnexpectedKeys.from_pretrained(temp.name, device_map={"linear": "cpu", "linear_2": "disk"})
+ def test_loading_respect_env_variable_for_threading(self):
+ """Test that we can correctly control threading during loading"""
+ model = BaseModel(PreTrainedConfig())
+
+ # Monkey patch Thread.__init__ to add a counter of launched threads
+ original_init = threading.Thread.__init__
+ counter = 0
+
+ def tracking_init(self, *args, **kwargs):
+ nonlocal counter
+ counter += 1
+ original_init(self, *args, **kwargs)
+
+ threading.Thread.__init__ = tracking_init
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ # Use threading
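+        # HF_DEACTIVATE_ASYNC_LOAD="0" keeps the default threaded (async) weight loading enabled,
+        # so from_pretrained should spawn at least one new thread.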
+ os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "0"
+ before = counter
+ _ = BaseModel.from_pretrained(tmpdirname)
+ after = counter
+ self.assertTrue(after - before > 0, "Loading should have spawned new threads!")
+
+ # Deactivate threading
+ os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
+ before = counter
+ _ = BaseModel.from_pretrained(tmpdirname)
+ after = counter
+ self.assertTrue(after == before, "It looks like loading did spawn new threads, but it should not have!")
+
+ # Reverse monkey patch
+ threading.Thread.__init__ = original_init
+
@slow
@require_torch
diff --git a/tests/utils/test_offline.py b/tests/utils/test_offline.py
index 771dd44d0b5c..20f1690bb719 100644
--- a/tests/utils/test_offline.py
+++ b/tests/utils/test_offline.py
@@ -182,7 +182,7 @@ def test_is_offline_mode(self):
"""
Test `is_offline_mode` helper (should respect both HF_HUB_OFFLINE and legacy TRANSFORMERS_OFFLINE env vars)
"""
- load = "from transformers.utils import is_offline_mode"
+ load = "from huggingface_hub import is_offline_mode"
run = "print(is_offline_mode())"
stdout, _ = self._execute_with_env(load, run)
diff --git a/utils/add_dates.py b/utils/add_dates.py
index 6719beae4b63..4b9c3e7514ba 100644
--- a/utils/add_dates.py
+++ b/utils/add_dates.py
@@ -2,7 +2,7 @@
import os
import re
import subprocess
-from datetime import date
+from datetime import date, datetime
from huggingface_hub import paper_info
@@ -176,14 +176,82 @@ def replace_paper_links(file_path: str) -> bool:
return False
-def insert_dates(model_card_list: list[str]):
- """Insert release and commit dates into model cards"""
+def _normalize_model_card_name(model_card: str) -> str:
+ """Ensure model card has .md extension"""
+ return model_card if model_card.endswith(".md") else f"{model_card}.md"
+
+
+def _should_skip_model_card(model_card: str) -> bool:
+ """Check if model card should be skipped"""
+ return model_card in ("auto.md", "timm_wrapper.md")
+
+
+def _read_model_card_content(model_card: str) -> str:
+ """Read and return the content of a model card"""
+ file_path = os.path.join(DOCS_PATH, model_card)
+ with open(file_path, "r", encoding="utf-8") as f:
+ return f.read()
+
+
+def _get_dates_pattern_match(content: str):
+    """Search for the dates pattern in content and return the match object, or None if not found"""
+ pattern = r"\n\*This model was released on (.*) and added to Hugging Face Transformers on (\d{4}-\d{2}-\d{2})\.\*"
+ return re.search(pattern, content)
+
+
+def _dates_differ_significantly(date1: str, date2: str) -> bool:
+ """Check if two dates differ by more than 1 day"""
+ try:
+ d1 = datetime.strptime(date1, "%Y-%m-%d")
+ d2 = datetime.strptime(date2, "%Y-%m-%d")
+ return abs((d1 - d2).days) > 1
+ except Exception:
+ return True # If dates can't be parsed, consider them different
+
+
+def check_missing_dates(model_card_list: list[str]) -> list[str]:
+ """Check which model cards are missing release dates and return their names"""
+ missing_dates = []
for model_card in model_card_list:
- if not model_card.endswith(".md"):
- model_card = f"{model_card}.md"
+ model_card = _normalize_model_card_name(model_card)
+ if _should_skip_model_card(model_card):
+ continue
- if model_card == "auto.md" or model_card == "timm_wrapper.md":
+ content = _read_model_card_content(model_card)
+ if not _get_dates_pattern_match(content):
+ missing_dates.append(model_card)
+
+ return missing_dates
+
+
+def check_incorrect_dates(model_card_list: list[str]) -> list[str]:
+ """Check which model cards have incorrect HF commit dates and return their names"""
+ incorrect_dates = []
+
+ for model_card in model_card_list:
+ model_card = _normalize_model_card_name(model_card)
+ if _should_skip_model_card(model_card):
+ continue
+
+ content = _read_model_card_content(model_card)
+ match = _get_dates_pattern_match(content)
+
+ if match:
+ existing_hf_date = match.group(2)
+ actual_hf_date = get_first_commit_date(model_name=model_card)
+
+ if _dates_differ_significantly(existing_hf_date, actual_hf_date):
+ incorrect_dates.append(model_card)
+
+ return incorrect_dates
+
+
+def insert_dates(model_card_list: list[str]):
+ """Insert or update release and commit dates in model cards"""
+ for model_card in model_card_list:
+ model_card = _normalize_model_card_name(model_card)
+ if _should_skip_model_card(model_card):
continue
file_path = os.path.join(DOCS_PATH, model_card)
@@ -193,55 +261,46 @@ def insert_dates(model_card_list: list[str]):
if links_replaced:
print(f"Updated paper links in {model_card}")
- pattern = (
- r"\n\*This model was released on (.*) and added to Hugging Face Transformers on (\d{4}-\d{2}-\d{2})\.\*"
- )
+ # Read content and ensure copyright disclaimer exists
+ content = _read_model_card_content(model_card)
+ markers = list(re.finditer(r"-->", content))
- # Check if the copyright disclaimer sections exists, if not, add one with 2025
- with open(file_path, "r", encoding="utf-8") as f:
- content = f.read()
- markers = list(re.finditer(r"-->", content)) # Dates info is placed right below this marker
if len(markers) == 0:
print(f"No marker found in {model_card}. Adding copyright disclaimer to the top.")
-
- # Add copyright disclaimer to the very top of the file
content = COPYRIGHT_DISCLAIMER + "\n\n" + content
with open(file_path, "w", encoding="utf-8") as f:
f.write(content)
markers = list(re.finditer(r"-->", content))
+ # Get dates
hf_commit_date = get_first_commit_date(model_name=model_card)
-
paper_link = get_paper_link(model_card=model_card, path=file_path)
- release_date = ""
- if not (paper_link == "No_paper" or paper_link == "blog"):
- release_date = get_release_date(paper_link)
- else:
+
+ if paper_link in ("No_paper", "blog"):
release_date = r"{release_date}"
+ else:
+ release_date = get_release_date(paper_link)
- match = re.search(pattern, content)
+ match = _get_dates_pattern_match(content)
- # If the dates info line already exists, preserve the existing release date unless it's a placeholder, and update the HF commit date if needed
+ # Update or insert the dates line
if match:
- existing_release_date = match.group(1) # The release date part
- existing_hf_date = match.group(2) # The existing HF date part
- release_date = (
- release_date
- if (existing_release_date == r"{release_date}" or existing_release_date == "None")
- else existing_release_date
- )
+ # Preserve existing release date unless it's a placeholder
+ existing_release_date = match.group(1)
+ existing_hf_date = match.group(2)
+
+ if existing_release_date not in (r"{release_date}", "None"):
+ release_date = existing_release_date
+
if existing_hf_date != hf_commit_date or existing_release_date != release_date:
- old_line = match.group(0) # Full matched line
+ old_line = match.group(0)
new_line = f"\n*This model was released on {release_date} and added to Hugging Face Transformers on {hf_commit_date}.*"
-
content = content.replace(old_line, new_line)
with open(file_path, "w", encoding="utf-8") as f:
f.write(content)
-
- # If the dates info line does not exist, add it
else:
+ # Insert new dates line after copyright marker
insert_index = markers[0].end()
-
date_info = f"\n*This model was released on {release_date} and added to Hugging Face Transformers on {hf_commit_date}.*"
content = content[:insert_index] + date_info + content[insert_index:]
with open(file_path, "w", encoding="utf-8") as f:
@@ -262,19 +321,41 @@ def get_all_model_cards():
return sorted(model_cards)
-def main(all=False, auto=True, models=None):
+def main(all=False, models=None, check_only=False):
+ if check_only:
+ # Check all model cards for missing dates
+ all_model_cards = get_all_model_cards()
+ print(f"Checking all {len(all_model_cards)} model cards for missing dates...")
+ missing_dates = check_missing_dates(all_model_cards)
+
+ # Check modified model cards for incorrect dates
+ modified_cards = get_modified_cards()
+ print(f"Checking {len(modified_cards)} modified model cards for incorrect dates...")
+ incorrect_dates = check_incorrect_dates(modified_cards)
+
+ if missing_dates or incorrect_dates:
+ problematic_cards = missing_dates + incorrect_dates
+ model_names = [card.replace(".md", "") for card in problematic_cards]
+ raise ValueError(
+ f"Missing or incorrect dates in the following model cards: {' '.join(problematic_cards)}\n"
+ f"Run `python utils/add_dates.py --models {' '.join(model_names)}` to fix them."
+ )
+ print("All dates are present and correct!")
+ return
+
+ # Determine which model cards to process
if all:
model_cards = get_all_model_cards()
print(f"Processing all {len(model_cards)} model cards from docs directory")
- elif auto:
+ elif models:
+ model_cards = models
+ print(f"Processing specified model cards: {model_cards}")
+ else:
model_cards = get_modified_cards()
if not model_cards:
print("No modified model cards found.")
return
print(f"Processing modified model cards: {model_cards}")
- else:
- model_cards = models
- print(f"Processing specified model cards: {model_cards}")
insert_dates(model_cards)
@@ -282,13 +363,10 @@ def main(all=False, auto=True, models=None):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Add release and commit dates to model cards")
group = parser.add_mutually_exclusive_group(required=False)
- group.add_argument(
- "--auto", action="store_true", help="Automatically process modified model cards from git status"
- )
group.add_argument("--models", nargs="+", help="Specify model cards to process (without .md extension)")
group.add_argument("--all", action="store_true", help="Process all model cards in the docs directory")
+    group.add_argument("--check-only", action="store_true", help="Only check that dates are present and correct")
- parser.set_defaults(auto=True)
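+    # Usage sketch (model names below are only examples):
+    #   python utils/add_dates.py                      # process model cards modified in git
+    #   python utils/add_dates.py --models bert gpt2   # process specific model cards
+    #   python utils/add_dates.py --all                # process every model card
+    #   python utils/add_dates.py --check-only         # fail if any dates are missing or incorrect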
args = parser.parse_args()
- main(args.all, args.auto, args.models)
+ main(args.all, args.models, args.check_only)
diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py
index 34ebf37b2fd0..f345bb9aad81 100644
--- a/utils/check_docstrings.py
+++ b/utils/check_docstrings.py
@@ -289,6 +289,9 @@ class DecoratedItem:
"JukeboxTokenizer",
"LEDConfig",
"LEDTokenizerFast",
+ "LasrEncoderConfig",
+ "LasrFeatureExtractor",
+ "LasrTokenizer",
"LayoutLMForQuestionAnswering",
"LayoutLMTokenizerFast",
"LayoutLMv2Config",
@@ -367,6 +370,7 @@ class DecoratedItem:
"OpenLlamaConfig",
"PLBartConfig",
"ParakeetCTCConfig",
+ "LasrCTCConfig",
"PegasusConfig",
"PegasusTokenizer",
"PegasusTokenizerFast",
diff --git a/utils/models_to_deprecate.py b/utils/models_to_deprecate.py
index 14565904d7e3..24313f65419d 100644
--- a/utils/models_to_deprecate.py
+++ b/utils/models_to_deprecate.py
@@ -105,6 +105,7 @@
"maskformer": ["maskformer-swin"],
"mbart": ["mbart50"],
"parakeet": ["parakeet_ctc", "parakeet_encoder"],
+ "lasr": ["lasr_ctc", "lasr_encoder"],
"perception_lm": ["perception_encoder"],
"pix2struct": ["deplot", "matcha"],
"qwen2_5_vl": ["qwen2_5_vl_text"],
diff --git a/utils/process_circleci_workflow_test_reports.py b/utils/process_circleci_workflow_test_reports.py
index eb61f6d586e5..cc828de5a905 100644
--- a/utils/process_circleci_workflow_test_reports.py
+++ b/utils/process_circleci_workflow_test_reports.py
@@ -14,6 +14,8 @@
import argparse
import json
import os
+import re
+from collections import Counter
import requests
@@ -22,64 +24,123 @@
parser = argparse.ArgumentParser()
parser.add_argument("--workflow_id", type=str, required=True)
args = parser.parse_args()
- workflow_id = args.workflow_id
r = requests.get(
- f"https://circleci.com/api/v2/workflow/{workflow_id}/job",
+ f"https://circleci.com/api/v2/workflow/{args.workflow_id}/job",
headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")},
)
jobs = r.json()["items"]
os.makedirs("outputs", exist_ok=True)
-
workflow_summary = {}
- # for each job, download artifacts
+ failure_entries = []
+
for job in jobs:
- project_slug = job["project_slug"]
if job["name"].startswith(("tests_", "examples_", "pipelines_")):
- url = f"https://circleci.com/api/v2/project/{project_slug}/{job['job_number']}/artifacts"
+ url = f"https://circleci.com/api/v2/project/{job['project_slug']}/{job['job_number']}/artifacts"
r = requests.get(url, headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")})
job_artifacts = r.json()["items"]
- os.makedirs(job["name"], exist_ok=True)
os.makedirs(f"outputs/{job['name']}", exist_ok=True)
job_test_summaries = {}
+ job_failure_lines = {}
+
for artifact in job_artifacts:
- if artifact["path"].startswith("reports/") and artifact["path"].endswith("/summary_short.txt"):
- node_index = artifact["node_index"]
- url = artifact["url"]
+ url = artifact["url"]
+ if artifact["path"].endswith("/summary_short.txt"):
+ r = requests.get(url, headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")})
+ job_test_summaries[artifact["node_index"]] = r.text
+ elif artifact["path"].endswith("/failures_line.txt"):
r = requests.get(url, headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")})
- test_summary = r.text
- job_test_summaries[node_index] = test_summary
+ job_failure_lines[artifact["node_index"]] = r.text
summary = {}
for node_index, node_test_summary in job_test_summaries.items():
for line in node_test_summary.splitlines():
if line.startswith("PASSED "):
- test = line[len("PASSED ") :]
- summary[test] = "passed"
+ summary[line[7:]] = "passed"
elif line.startswith("FAILED "):
- test = line[len("FAILED ") :].split()[0]
- summary[test] = "failed"
- # failed before passed
+ summary[line[7:].split()[0]] = "failed"
+
summary = dict(sorted(summary.items(), key=lambda x: (x[1], x[0])))
workflow_summary[job["name"]] = summary
- # collected version
with open(f"outputs/{job['name']}/test_summary.json", "w") as fp:
json.dump(summary, fp, indent=4)
+ # Collect failure details
+ for node_index, summary_text in job_test_summaries.items():
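+            # Keep only lines that look like "<test>: <error>" from failures_line.txt, skipping
+            # pytest separators ("===", "___") and the "short test summary" header.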
+            failure_lines_list = [
+                ln.strip()
+                for ln in job_failure_lines.get(node_index, "").splitlines()
+                if ln.strip() and not ln.strip().startswith(("=", "_", "short test summary")) and ": " in ln
+            ]
+
+ failure_idx = 0
+ for line in summary_text.splitlines():
+ if line.startswith("FAILED ") and " - Failed: (subprocess)" not in line:
+ test_name, _, short_error = line[7:].strip().partition(" - ")
+ test_name = test_name.strip()
+ parts = test_name.split("::", 1)[0].split("/")
+ model_name = parts[2] if len(parts) >= 3 and test_name.startswith("tests/models/") else None
+ full_error = (
+ failure_lines_list[failure_idx] if failure_idx < len(failure_lines_list) else short_error
+ )
+
+ failure_entries.append(
+ {
+ "job_name": job["name"],
+ "test_name": test_name,
+ "short_error": short_error,
+ "error": full_error,
+ "model_name": model_name,
+ }
+ )
+ failure_idx += 1
+
+ # Build workflow summary
new_workflow_summary = {}
for job_name, job_summary in workflow_summary.items():
for test, status in job_summary.items():
- if test not in new_workflow_summary:
- new_workflow_summary[test] = {}
- new_workflow_summary[test][job_name] = status
+ new_workflow_summary.setdefault(test, {})[job_name] = status
- for test, result in new_workflow_summary.items():
- new_workflow_summary[test] = dict(sorted(result.items()))
- new_workflow_summary = dict(sorted(new_workflow_summary.items()))
+ new_workflow_summary = {
+ test: dict(sorted(result.items())) for test, result in sorted(new_workflow_summary.items())
+ }
with open("outputs/test_summary.json", "w") as fp:
json.dump(new_workflow_summary, fp, indent=4)
+
+ # Aggregate failures by test and model
+ by_test, by_model = {}, {}
+
+ for entry in failure_entries:
+ # Normalize test name
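+        # (drop the "[...]" parametrization suffix and any trailing "_NN..." numeric suffix so
+        # parametrized variants of the same test are aggregated together)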
+ normalized = entry["test_name"].split("[", 1)[0]
+ parts = normalized.split("::")
+ normalized = "::".join(parts[:-1] + [re.sub(r"_\d{2,}.*$", "", parts[-1])])
+
+ by_test.setdefault(normalized, {"count": 0, "errors": Counter(), "jobs": set(), "variants": set()})
+ by_test[normalized]["count"] += 1
+ by_test[normalized]["errors"][entry["error"]] += 1
+ by_test[normalized]["jobs"].add(entry["job_name"])
+ by_test[normalized]["variants"].add(entry["test_name"])
+
+ if entry["model_name"]:
+ by_model.setdefault(entry["model_name"], {"count": 0, "errors": Counter(), "tests": set()})
+ by_model[entry["model_name"]]["count"] += 1
+ by_model[entry["model_name"]]["errors"][entry["error"]] += 1
+ by_model[entry["model_name"]]["tests"].add(entry["test_name"])
+
+ # Convert Counter and sets to dicts/lists for JSON serialization
+ for info in by_test.values():
+ info["errors"] = dict(info["errors"].most_common())
+ info["jobs"] = sorted(info["jobs"])
+ info["variants"] = sorted(info["variants"])
+ for info in by_model.values():
+ info["errors"] = dict(info["errors"].most_common())
+ info["tests"] = sorted(info["tests"])
+
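+    # failure_summary.json: raw failure entries plus aggregations keyed by normalized
+    # test name ("by_test") and by model name ("by_model").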
+ with open("outputs/failure_summary.json", "w") as fp:
+ json.dump({"failures": failure_entries, "by_test": by_test, "by_model": by_model}, fp, indent=4)
diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py
index 0b289c065f3a..c7a9578f5192 100644
--- a/utils/tests_fetcher.py
+++ b/utils/tests_fetcher.py
@@ -1100,6 +1100,7 @@ def parse_commit_message(commit_message: str) -> dict[str, bool]:
"pipelines_torch": r"tests/models/.*/test_modeling_.*",
"tests_hub": r"tests/.*",
"tests_non_model": r"tests/[^/]*?/test_.*\.py",
+ "tests_training_ci": r"tests/models/.*/test_modeling_.*",
}