From 9c6357c7d22d6e35605a1a0c8bb9e9a4f2c054ed Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 5 Dec 2025 21:23:24 +0000 Subject: [PATCH 1/6] support saving/loading multiple sub_processor of the same kind --- src/transformers/processing_utils.py | 110 ++++++++++++++++++---- tests/models/auto/test_processor_auto.py | 113 +++++++++++++++++++++++ 2 files changed, 204 insertions(+), 19 deletions(-) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index f54ddeb1b2a6..85eb5cdd3f9d 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -130,6 +130,26 @@ def keys(self): "video_processor": "BaseVideoProcessor", } + +def _get_modality_for_attribute(attribute_name: str) -> str: + """ + Get the canonical modality type for a given attribute name. + + For example: + - "image_processor" -> "image_processor" + - "encoder_image_processor" -> "image_processor" + - "text_tokenizer" -> "tokenizer" + - "my_feature_extractor" -> "feature_extractor" + """ + for modality in MODALITY_TO_AUTOPROCESSOR_MAPPING.keys(): + if modality in attribute_name: + return modality + raise ValueError( + f"Cannot determine modality for attribute '{attribute_name}'. " + f"Attribute name must contain one of: {list(MODALITY_TO_AUTOPROCESSOR_MAPPING.keys())}" + ) + + if sys.version_info >= (3, 11): Unpack = typing.Unpack else: @@ -664,8 +684,10 @@ def check_argument_for_proper_class(self, argument_name, argument): mismatch between expected and actual class, an error is raise. Otherwise, the proper retrieved class is returned. """ - if argument_name not in MODALITY_TO_BASE_CLASS_MAPPING and "tokenizer" in argument_name: - argument_name = "tokenizer" + # If the exact attribute name is not in the mapping, use its canonical modality + # (e.g., "encoder_tokenizer" -> "tokenizer") + if argument_name not in MODALITY_TO_BASE_CLASS_MAPPING: + argument_name = _get_modality_for_attribute(argument_name) class_name = MODALITY_TO_BASE_CLASS_MAPPING.get(argument_name) if isinstance(class_name, tuple): proper_class = tuple(self.get_possibly_dynamic_module(n) for n in class_name if n is not None) @@ -696,9 +718,13 @@ def to_dict(self) -> dict[str, Any]: # extra attributes to be kept attrs_to_save += ["auto_map"] + # Remove tokenizers from output - they have their own vocab files and are saved separately. + # All other sub-processors (image_processor, feature_extractor, etc.) are kept in processor_config.json. for attribute in self.__class__.get_attributes(): - if "tokenizer" in attribute and attribute in output: - del output[attribute] + if attribute in output: + modality = _get_modality_for_attribute(attribute) + if modality == "tokenizer": + del output[attribute] if "chat_template" in output: del output["chat_template"] @@ -820,13 +846,15 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs): if hasattr(attribute, "_set_processor_class"): attribute._set_processor_class(self.__class__.__name__) - # Save the tokenizer in its own vocab file. The other attributes are saved as part of `processor_config.json` - if attribute_name == "tokenizer": - attribute.save_pretrained(save_directory) - # if a model has multiple tokenizers, save the additional tokenizers in their own folders. - # Note that the additional tokenizers must have "tokenizer" in their attribute name. - elif "tokenizer" in attribute_name: - attribute.save_pretrained(os.path.join(save_directory, attribute_name)) + modality = _get_modality_for_attribute(attribute_name) + is_primary = attribute_name == modality + if modality == "tokenizer": + # Save the tokenizer in its own vocab file. The other attributes are saved as part of `processor_config.json` + if is_primary: + attribute.save_pretrained(save_directory) + else: + # if a model has multiple tokenizers, save the additional tokenizers in their own folders. + attribute.save_pretrained(os.path.join(save_directory, attribute_name)) elif attribute._auto_class is not None: custom_object_save(attribute, save_directory, config=attribute) @@ -1394,8 +1422,9 @@ def from_pretrained( if token is not None: kwargs["token"] = token - args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs) + # Get processor_dict first so we can use it to instantiate non-tokenizer sub-processors processor_dict, kwargs = cls.get_processor_dict(pretrained_model_name_or_path, **kwargs) + args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, processor_dict, **kwargs) return cls.from_args_and_dict(args, processor_dict, **kwargs) @classmethod @@ -1406,7 +1435,7 @@ def get_attributes(cls): # don't treat audio_tokenizer as an attribute if sub_processor_type == "audio_tokenizer": continue - if sub_processor_type in MODALITY_TO_AUTOPROCESSOR_MAPPING or "tokenizer" in sub_processor_type: + if any(modality in sub_processor_type for modality in MODALITY_TO_AUTOPROCESSOR_MAPPING.keys()): attributes.append(sub_processor_type) # Legacy processors may not override `__init__` and instead expose modality @@ -1420,7 +1449,7 @@ def get_attributes(cls): inferred_attribute = attribute_name[: -len("_class")] if inferred_attribute == "audio_tokenizer": continue - if inferred_attribute in MODALITY_TO_AUTOPROCESSOR_MAPPING or "tokenizer" in inferred_attribute: + if any(modality in inferred_attribute for modality in MODALITY_TO_AUTOPROCESSOR_MAPPING.keys()): attributes.append(inferred_attribute) return attributes @@ -1448,20 +1477,36 @@ def register_for_auto_class(cls, auto_class="AutoProcessor"): cls._auto_class = auto_class @classmethod - def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, processor_dict=None, **kwargs): """ Identify and instantiate the subcomponents of Processor classes, such as image processors, tokenizers, and feature extractors. This method inspects the processor's `__init__` signature to identify parameters that correspond to known modality types (image_processor, tokenizer, feature_extractor, etc.) or contain - "tokenizer" in their name. It then uses the appropriate Auto class (AutoImageProcessor, AutoTokenizer, etc.) - from `MODALITY_TO_AUTOPROCESSOR_MAPPING` to load each subcomponent via `.from_pretrained()`. For tokenizer-like - parameters not explicitly in the mapping, the method uses AutoTokenizer with a subfolder argument. + modality names in their attribute name. + + For tokenizers: Uses the appropriate Auto class (AutoTokenizer) to load via `.from_pretrained()`. + Additional tokenizers (e.g., "decoder_tokenizer") are loaded from subfolders. + + For other sub-processors (image_processor, feature_extractor, etc.): Primary ones are loaded via + Auto class. Additional ones are instantiated from the config stored in processor_config.json + (passed as processor_dict). + + Args: + pretrained_model_name_or_path: Path or model id to load from. + processor_dict: Optional dict containing processor config (from processor_config.json). + Required when loading additional non-tokenizer sub-processors. """ args = [] + processor_dict = processor_dict if processor_dict is not None else {} + # get args from processor init signature sub_processors = cls.get_attributes() for sub_processor_type in sub_processors: - if sub_processor_type in MODALITY_TO_AUTOPROCESSOR_MAPPING: + modality = _get_modality_for_attribute(sub_processor_type) + is_primary = sub_processor_type == modality + + if is_primary: + # Primary non-tokenizer sub-processor: load via Auto class auto_processor_class = MODALITY_TO_AUTOPROCESSOR_MAPPING[sub_processor_type] sub_processor = auto_processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs) args.append(sub_processor) @@ -1474,6 +1519,33 @@ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs) ) args.append(sub_processor) + elif sub_processor_type in processor_dict: + # Additional non-tokenizer sub-processor: instantiate from config in processor_dict + sub_processor_config = processor_dict[sub_processor_type] + if isinstance(sub_processor_config, dict): + # Determine the class to instantiate + # Image processors have 'image_processor_type', feature extractors have 'feature_extractor_type' + type_key = f"{modality}_type" + class_name = sub_processor_config.get(type_key) + if class_name is None: + raise ValueError( + f"Cannot instantiate {sub_processor_type}: missing '{type_key}' in config. " + f"Config keys: {list(sub_processor_config.keys())}" + ) + processor_class = cls.get_possibly_dynamic_module(class_name) + sub_processor = processor_class(**sub_processor_config) + args.append(sub_processor) + else: + raise ValueError( + f"Expected dict for {sub_processor_type} in processor_config.json, " + f"got {type(sub_processor_config)}" + ) + else: + raise ValueError( + f"Cannot find config for {sub_processor_type} in processor_config.json. " + f"Available keys: {list(processor_dict.keys())}" + ) + return args @staticmethod diff --git a/tests/models/auto/test_processor_auto.py b/tests/models/auto/test_processor_auto.py index 63f28d3dea9d..4e618ea0f9b5 100644 --- a/tests/models/auto/test_processor_auto.py +++ b/tests/models/auto/test_processor_auto.py @@ -35,6 +35,7 @@ AutoTokenizer, BaseVideoProcessor, BertTokenizer, + CLIPImageProcessorFast, FeatureExtractionMixin, ImageProcessingMixin, LlamaTokenizer, @@ -42,6 +43,7 @@ LlavaProcessor, ProcessorMixin, SiglipImageProcessor, + SiglipImageProcessorFast, Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, @@ -431,6 +433,117 @@ def test_auto_processor_save_load(self): second_processor = AutoProcessor.from_pretrained(tmp_dir) self.assertEqual(second_processor.__class__.__name__, processor.__class__.__name__) + def test_processor_with_multiple_tokenizers_save_load(self): + """Test that processors with multiple tokenizers save and load correctly.""" + + class DualTokenizerProcessor(ProcessorMixin): + """A processor with two tokenizers and an image processor.""" + + def __init__(self, tokenizer, decoder_tokenizer, image_processor): + super().__init__(tokenizer, decoder_tokenizer, image_processor) + + # Create processor with multiple tokenizers + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertForMaskedLM") + decoder_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + image_processor = SiglipImageProcessor() + + processor = DualTokenizerProcessor( + tokenizer=tokenizer, + decoder_tokenizer=decoder_tokenizer, + image_processor=image_processor, + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + processor.save_pretrained(tmp_dir) + + # Verify directory structure: primary tokenizer in root, additional in subfolder + self.assertTrue(os.path.exists(os.path.join(tmp_dir, "tokenizer_config.json"))) + self.assertTrue(os.path.isdir(os.path.join(tmp_dir, "decoder_tokenizer"))) + self.assertTrue(os.path.exists(os.path.join(tmp_dir, "decoder_tokenizer", "tokenizer_config.json"))) + + # Verify processor_config.json contains image_processor but not tokenizers + with open(os.path.join(tmp_dir, "processor_config.json")) as f: + processor_config = json.load(f) + self.assertIn("image_processor", processor_config) + self.assertNotIn("tokenizer", processor_config) + self.assertNotIn("decoder_tokenizer", processor_config) + + # Reload the full processor and verify all attributes + loaded_processor = DualTokenizerProcessor.from_pretrained(tmp_dir) + + # Verify the processor has all expected attributes + self.assertTrue(hasattr(loaded_processor, "tokenizer")) + self.assertTrue(hasattr(loaded_processor, "decoder_tokenizer")) + self.assertTrue(hasattr(loaded_processor, "image_processor")) + + # Verify tokenizers loaded correctly + self.assertEqual(loaded_processor.tokenizer.vocab_size, tokenizer.vocab_size) + self.assertEqual(loaded_processor.decoder_tokenizer.vocab_size, decoder_tokenizer.vocab_size) + + # Verify image processor loaded correctly + self.assertEqual(loaded_processor.image_processor.size, image_processor.size) + + def test_processor_with_multiple_image_processors_save_load(self): + """Test that processors with multiple image processors save and load correctly.""" + + class DualImageProcessorProcessor(ProcessorMixin): + """A processor with two image processors and a tokenizer.""" + + def __init__(self, tokenizer, image_processor, encoder_image_processor): + super().__init__(tokenizer, image_processor, encoder_image_processor) + + # Create processor with multiple image processors + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertForMaskedLM") + image_processor = SiglipImageProcessorFast(size={"height": 224, "width": 224}) + encoder_image_processor = CLIPImageProcessorFast(size={"height": 384, "width": 384}) + + processor = DualImageProcessorProcessor( + tokenizer=tokenizer, + image_processor=image_processor, + encoder_image_processor=encoder_image_processor, + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + processor.save_pretrained(tmp_dir) + + # Verify processor_config.json contains both image processors + with open(os.path.join(tmp_dir, "processor_config.json")) as f: + processor_config = json.load(f) + self.assertIn("image_processor", processor_config) + self.assertIn("encoder_image_processor", processor_config) + self.assertNotIn("tokenizer", processor_config) + + # Verify both image processors have the correct type key for instantiation + self.assertIn("image_processor_type", processor_config["image_processor"]) + self.assertIn("image_processor_type", processor_config["encoder_image_processor"]) + self.assertEqual(processor_config["image_processor"]["image_processor_type"], "SiglipImageProcessorFast") + self.assertEqual( + processor_config["encoder_image_processor"]["image_processor_type"], "CLIPImageProcessorFast" + ) + + # Verify the sizes are different (to ensure they're separate configs) + self.assertEqual(processor_config["image_processor"]["size"], {"height": 224, "width": 224}) + self.assertEqual(processor_config["encoder_image_processor"]["size"], {"height": 384, "width": 384}) + + # Reload the full processor and verify all attributes + loaded_processor = DualImageProcessorProcessor.from_pretrained(tmp_dir) + + # Verify the processor has all expected attributes + self.assertTrue(hasattr(loaded_processor, "tokenizer")) + self.assertTrue(hasattr(loaded_processor, "image_processor")) + self.assertTrue(hasattr(loaded_processor, "encoder_image_processor")) + + # Verify tokenizer loaded correctly + self.assertEqual(loaded_processor.tokenizer.vocab_size, tokenizer.vocab_size) + + # Verify image processors loaded correctly with their distinct sizes + self.assertEqual(loaded_processor.image_processor.size, {"height": 224, "width": 224}) + self.assertEqual(loaded_processor.encoder_image_processor.size, {"height": 384, "width": 384}) + + # Verify they are different types + self.assertIsInstance(loaded_processor.image_processor, SiglipImageProcessorFast) + self.assertIsInstance(loaded_processor.encoder_image_processor, CLIPImageProcessorFast) + @is_staging_test class ProcessorPushToHubTester(unittest.TestCase): From f3bd01c9156559ce829a7ba3a7e14c82cecaa985 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 5 Dec 2025 21:41:06 +0000 Subject: [PATCH 2/6] standardize all processors --- .../models/audioflamingo3/processing_audioflamingo3.py | 4 ---- src/transformers/models/auto/feature_extraction_auto.py | 1 + src/transformers/models/auto/processing_auto.py | 2 ++ src/transformers/models/auto/tokenization_auto.py | 3 +++ .../models/phi4_multimodal/processing_phi4_multimodal.py | 2 -- src/transformers/models/pix2struct/processing_pix2struct.py | 4 ---- 6 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/audioflamingo3/processing_audioflamingo3.py b/src/transformers/models/audioflamingo3/processing_audioflamingo3.py index bc14f0d6cde4..b53dcd165464 100644 --- a/src/transformers/models/audioflamingo3/processing_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/processing_audioflamingo3.py @@ -74,10 +74,6 @@ class AudioFlamingo3Processor(ProcessorMixin): Special token used to represent audio inputs in the chat template. """ - attributes = ["feature_extractor", "tokenizer"] - feature_extractor_class = "WhisperFeatureExtractor" - tokenizer_class = "Qwen2TokenizerFast" - def __init__( self, feature_extractor, diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index a9008af06ab6..6963447b5b6f 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -38,6 +38,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict( [ ("audio-spectrogram-transformer", "ASTFeatureExtractor"), + ("audioflamingo3", "WhisperFeatureExtractor"), ("clap", "ClapFeatureExtractor"), ("clvp", "ClvpFeatureExtractor"), ("csm", "EncodecFeatureExtractor"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 6d08bf37ebab..88dde801bba3 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -93,6 +93,8 @@ ("kosmos-2", "Kosmos2Processor"), ("kosmos-2.5", "Kosmos2_5Processor"), ("kyutai_speech_to_text", "KyutaiSpeechToTextProcessor"), + ("lasr_ctc", "LasrProcessor"), + ("lasr_encoder", "LasrProcessor"), ("layoutlmv2", "LayoutLMv2Processor"), ("layoutlmv3", "LayoutLMv3Processor"), ("layoutxlm", "LayoutXLMProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 31c6a783726b..bf4de43e30df 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -70,6 +70,7 @@ ("align", "BertTokenizer" if is_tokenizers_available() else None), ("arcee", "LlamaTokenizerFast" if is_tokenizers_available() else None), ("aria", "LlamaTokenizerFast" if is_tokenizers_available() else None), + ("audioflamingo3", "Qwen2TokenizerFast" if is_tokenizers_available() else None), ("aya_vision", "CohereTokenizer" if is_tokenizers_available() else None), ("bark", "BertTokenizer" if is_tokenizers_available() else None), ("bart", "RobertaTokenizer" if is_tokenizers_available() else None), @@ -183,6 +184,8 @@ ("jetmoe", "LlamaTokenizerFast" if is_tokenizers_available() else None), ("kosmos-2", "XLMRobertaTokenizer" if is_tokenizers_available() else None), ("kosmos-2.5", "PreTrainedTokenizerFast" if is_tokenizers_available() else None), + ("lasr_ctc", "ParakeetTokenizerFast" if is_tokenizers_available() else None), + ("lasr_encoder", "ParakeetTokenizerFast" if is_tokenizers_available() else None), ("layoutlm", "BertTokenizer" if is_tokenizers_available() else None), ("layoutlmv2", "LayoutLMv2Tokenizer" if is_tokenizers_available() else None), ("layoutlmv3", "LayoutLMv3Tokenizer" if is_tokenizers_available() else None), diff --git a/src/transformers/models/phi4_multimodal/processing_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/processing_phi4_multimodal.py index 8eec69b0448e..cde089821878 100644 --- a/src/transformers/models/phi4_multimodal/processing_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/processing_phi4_multimodal.py @@ -58,8 +58,6 @@ class Phi4MultimodalProcessor(ProcessorMixin): The fake audio token pattern. """ - audio_processor_class = "Phi4MultimodalFeatureExtractor" - def __init__( self, image_processor, diff --git a/src/transformers/models/pix2struct/processing_pix2struct.py b/src/transformers/models/pix2struct/processing_pix2struct.py index 1fe236339a7c..3ce09bf9d7fc 100644 --- a/src/transformers/models/pix2struct/processing_pix2struct.py +++ b/src/transformers/models/pix2struct/processing_pix2struct.py @@ -61,10 +61,6 @@ class Pix2StructProcessor(ProcessorMixin): An instance of ['T5Tokenizer`]. The tokenizer is a required input. """ - attributes = ["image_processor", "tokenizer"] - image_processor_class = "Pix2StructImageProcessor" - tokenizer_class = ("T5Tokenizer",) - def __init__(self, image_processor, tokenizer): tokenizer.return_token_type_ids = False super().__init__(image_processor, tokenizer) From c84b5642dace2eeb70e08c15e6eaa74dda492154 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 5 Dec 2025 21:46:36 +0000 Subject: [PATCH 3/6] remove tokenizer_class from lasr --- src/transformers/models/lasr/processing_lasr.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/lasr/processing_lasr.py b/src/transformers/models/lasr/processing_lasr.py index 3396986866e2..7a4661c6a6ce 100644 --- a/src/transformers/models/lasr/processing_lasr.py +++ b/src/transformers/models/lasr/processing_lasr.py @@ -47,8 +47,6 @@ class LasrProcessorKwargs(ProcessingKwargs, total=False): class LasrProcessor(ProcessorMixin): - tokenizer_class = "ParakeetTokenizerFast" - def __init__(self, feature_extractor, tokenizer): super().__init__(feature_extractor, tokenizer) From abd038d1886b7f759f593ece420a0df630bbde9c Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 5 Dec 2025 21:51:41 +0000 Subject: [PATCH 4/6] fix modular --- src/transformers/models/lasr/modular_lasr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/lasr/modular_lasr.py b/src/transformers/models/lasr/modular_lasr.py index c02b2ae0f1c3..75170f0009a5 100644 --- a/src/transformers/models/lasr/modular_lasr.py +++ b/src/transformers/models/lasr/modular_lasr.py @@ -97,7 +97,7 @@ def _decode( class LasrProcessor(ParakeetProcessor): - tokenizer_class = "ParakeetTokenizerFast" + pass class LasrEncoderConfig(ParakeetEncoderConfig): From 114a48bf658262b37aa2911cfacf0d71107547ac Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 5 Dec 2025 23:01:02 +0000 Subject: [PATCH 5/6] fix kwargs logic --- src/transformers/processing_utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 85eb5cdd3f9d..d42aa05bd4c9 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -1423,9 +1423,9 @@ def from_pretrained( kwargs["token"] = token # Get processor_dict first so we can use it to instantiate non-tokenizer sub-processors - processor_dict, kwargs = cls.get_processor_dict(pretrained_model_name_or_path, **kwargs) + processor_dict, instantiation_kwargs = cls.get_processor_dict(pretrained_model_name_or_path, **kwargs) args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, processor_dict, **kwargs) - return cls.from_args_and_dict(args, processor_dict, **kwargs) + return cls.from_args_and_dict(args, processor_dict, **instantiation_kwargs) @classmethod def get_attributes(cls): @@ -1498,6 +1498,8 @@ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, processor """ args = [] processor_dict = processor_dict if processor_dict is not None else {} + # Remove subfolder from kwargs to avoid duplicate keyword arguments + subfolder = kwargs.pop("subfolder", "") # get args from processor init signature sub_processors = cls.get_attributes() @@ -1508,14 +1510,17 @@ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, processor if is_primary: # Primary non-tokenizer sub-processor: load via Auto class auto_processor_class = MODALITY_TO_AUTOPROCESSOR_MAPPING[sub_processor_type] - sub_processor = auto_processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs) + sub_processor = auto_processor_class.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, **kwargs + ) args.append(sub_processor) elif "tokenizer" in sub_processor_type: # Special case: tokenizer-like parameters not in the mapping (e.g., "protein_tokenizer") # Load using AutoTokenizer with subfolder auto_processor_class = MODALITY_TO_AUTOPROCESSOR_MAPPING["tokenizer"] + tokenizer_subfolder = os.path.join(subfolder, sub_processor_type) if subfolder else sub_processor_type sub_processor = auto_processor_class.from_pretrained( - pretrained_model_name_or_path, subfolder=sub_processor_type, **kwargs + pretrained_model_name_or_path, subfolder=tokenizer_subfolder, **kwargs ) args.append(sub_processor) From ed400c3b6a7f2b4b8a76e3427c858a4c973b360f Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 9 Dec 2025 17:08:05 +0000 Subject: [PATCH 6/6] override _load_tokenizer_from_pretrained in pixtral and fuyu --- .../models/fuyu/processing_fuyu.py | 16 +++++++ .../models/pixtral/processing_pixtral.py | 19 ++++++++ src/transformers/processing_utils.py | 44 +++++++++---------- 3 files changed, 56 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/fuyu/processing_fuyu.py b/src/transformers/models/fuyu/processing_fuyu.py index c6bcc0c35d4f..fecd9d28fcf3 100644 --- a/src/transformers/models/fuyu/processing_fuyu.py +++ b/src/transformers/models/fuyu/processing_fuyu.py @@ -347,6 +347,22 @@ class FuyuProcessor(ProcessorMixin): The tokenizer is a required input. """ + @classmethod + def _load_tokenizer_from_pretrained( + cls, sub_processor_type, pretrained_model_name_or_path, subfolder="", **kwargs + ): + """ + Override for BC. Fuyu uses TokenizersBackend and requires token_type_ids to be removed from model_input_names + because Fuyu uses mm_token_type_ids instead for multimodal token identification. ` + """ + from ...tokenization_utils_tokenizers import TokenizersBackend + + tokenizer = TokenizersBackend.from_pretrained(pretrained_model_name_or_path, **kwargs) + # Remove token_type_ids as Fuyu uses mm_token_type_ids instead + if "token_type_ids" in tokenizer.model_input_names: + tokenizer.model_input_names.remove("token_type_ids") + return tokenizer + def __init__(self, image_processor, tokenizer, **kwargs): super().__init__(image_processor=image_processor, tokenizer=tokenizer) self.image_processor = image_processor diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index b62deee98300..7b898242102d 100644 --- a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -87,6 +87,25 @@ class PixtralProcessor(ProcessorMixin): Special token used to denote the end of an image input. """ + @classmethod + def _load_tokenizer_from_pretrained( + cls, sub_processor_type, pretrained_model_name_or_path, subfolder="", **kwargs + ): + """ + Override for BC. Pixtral requires a modified pre_tokenizer with ByteLevel prepended to handle + the specific tokenization format expected by pretrained Pixtral models. + """ + from tokenizers import pre_tokenizers + + from ...models.llama import LlamaTokenizer + + tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) + # Add ByteLevel pre_tokenizer before the existing one + tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Sequence( + [pre_tokenizers.ByteLevel(False), tokenizer._tokenizer.pre_tokenizer] + ) + return tokenizer + def __init__( self, image_processor=None, diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index bfa85843efaa..90ec745dc9db 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -1475,6 +1475,24 @@ def register_for_auto_class(cls, auto_class="AutoProcessor"): cls._auto_class = auto_class + @classmethod + def _load_tokenizer_from_pretrained( + cls, sub_processor_type, pretrained_model_name_or_path, subfolder="", **kwargs + ): + auto_processor_class = MODALITY_TO_AUTOPROCESSOR_MAPPING["tokenizer"] + is_primary = sub_processor_type == "tokenizer" + + if is_primary: + # Primary tokenizer: load from root + tokenizer = auto_processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs) + else: + # Additional tokenizer: load from subfolder (e.g., "decoder_tokenizer") + tokenizer_subfolder = os.path.join(subfolder, sub_processor_type) if subfolder else sub_processor_type + tokenizer = auto_processor_class.from_pretrained( + pretrained_model_name_or_path, subfolder=tokenizer_subfolder, **kwargs + ) + return tokenizer + @classmethod def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, processor_dict=None, **kwargs): """ @@ -1505,21 +1523,10 @@ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, processor for sub_processor_type in sub_processors: modality = _get_modality_for_attribute(sub_processor_type) is_primary = sub_processor_type == modality - if "FuyuProcessor" in cls.__name__ and "tokenizer" in sub_processor_type: - from .tokenization_utils_tokenizers import TokenizersBackend - tokenizer = TokenizersBackend.from_pretrained(pretrained_model_name_or_path, **kwargs) - if "token_type_ids" in tokenizer.model_input_names: - tokenizer.model_input_names.remove("token_type_ids") - args.append(tokenizer) - elif "PixtralProcessor" in cls.__name__ and "tokenizer" in sub_processor_type: - from tokenizers import pre_tokenizers - - from .models.llama import LlamaTokenizer - - tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) - tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Sequence( - [pre_tokenizers.ByteLevel(False), tokenizer._tokenizer.pre_tokenizer] + if "tokenizer" in sub_processor_type: + tokenizer = cls._load_tokenizer_from_pretrained( + sub_processor_type, pretrained_model_name_or_path, subfolder=subfolder, **kwargs ) args.append(tokenizer) elif is_primary: @@ -1529,15 +1536,6 @@ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, processor pretrained_model_name_or_path, subfolder=subfolder, **kwargs ) args.append(sub_processor) - elif "tokenizer" in sub_processor_type: - # Special case: tokenizer-like parameters not in the mapping (e.g., "protein_tokenizer") - # Load using AutoTokenizer with subfolder - auto_processor_class = MODALITY_TO_AUTOPROCESSOR_MAPPING["tokenizer"] - tokenizer_subfolder = os.path.join(subfolder, sub_processor_type) if subfolder else sub_processor_type - sub_processor = auto_processor_class.from_pretrained( - pretrained_model_name_or_path, subfolder=tokenizer_subfolder, **kwargs - ) - args.append(sub_processor) elif sub_processor_type in processor_dict: # Additional non-tokenizer sub-processor: instantiate from config in processor_dict