From d99edae121812a25f369cfa8133f3345d92f0e54 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 4 Dec 2025 13:33:20 -0500
Subject: [PATCH] allow registration of custom checkpoint conversion mappings

---
 src/transformers/conversion_mapping.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py
index 5968bd08d406..115e04d70ff7 100644
--- a/src/transformers/conversion_mapping.py
+++ b/src/transformers/conversion_mapping.py
@@ -186,10 +186,22 @@ def _build_checkpoint_conversion_mapping():
 
 def get_checkpoint_conversion_mapping(model_type):
     global _checkpoint_conversion_mapping_cache
-    _checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping()
+    if _checkpoint_conversion_mapping_cache is None:
+        _checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping()
     return deepcopy(_checkpoint_conversion_mapping_cache.get(model_type))
 
 
+def register_checkpoint_conversion_mapping(
+    model_type: str, mapping: list[WeightConverter | WeightRenaming], overwrite: bool = False
+) -> None:
+    global _checkpoint_conversion_mapping_cache
+    if _checkpoint_conversion_mapping_cache is None:
+        _checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping()
+    if model_type in _checkpoint_conversion_mapping_cache and not overwrite:
+        raise ValueError(f"Model type {model_type} already exists in the checkpoint conversion mapping.")
+    _checkpoint_conversion_mapping_cache[model_type] = mapping
+
+
 # DO NOT MODIFY, KEPT FOR BC ONLY
 VLMS = [
     "aria",