From 79a07cc4f2d60c99428fca49e507ac4bb6968fc5 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 5 Dec 2025 18:52:44 +0100 Subject: [PATCH 1/2] avoid loading unecesary params --- src/transformers/core_model_loading.py | 44 +++++++++++++++----------- tests/utils/test_modeling_utils.py | 20 ++++++++++++ 2 files changed, 45 insertions(+), 19 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 01bb9c3770b4..c69d954e77ab 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -644,22 +644,17 @@ def set_param_for_module( setattr(module_obj, param_name, param_value) -def offload_and_maybe_resave_param( +def offload_and_resave_param( target_name: str, param: torch.Tensor, missing_keys: MutableSet[str], disk_offload_folder: str, disk_offload_index: dict, - applied_ops: WeightConverter | WeightRenaming, ) -> dict: - """Takes care of correctly offloading `param`. If it's not already present in the `disk_offload_index`, or if any - WeightConverter operations have been applied, it will resave the new parameter. Otherwise, it will use the original - `disk_offload_index` for this given param.""" + """Takes care of correctly offloading `param`. It will resave the new parameter, and update the index.""" # We need to remove from missing keys missing_keys.discard(target_name) - # If not already offloaded, or if we applied any special Operation except Renaming, we need to re-save - if target_name not in disk_offload_index or isinstance(applied_ops, WeightConverter): - disk_offload_index = offload_weight(param, target_name, disk_offload_folder, disk_offload_index) + disk_offload_index = offload_weight(param, target_name, disk_offload_folder, disk_offload_index) return disk_offload_index @@ -856,10 +851,11 @@ def convert_and_load_state_dict_in_model( if source_pattern is not None: new_converter = deepcopy(pattern_to_converter[source_pattern]) # each target key gets its own converter instance - mapping = param_name_to_load.setdefault(renamed_key, new_converter) + mapping = param_name_to_load.get(renamed_key, new_converter) # Otherwise, only potential renaming else: - mapping = param_name_to_load.setdefault(renamed_key, WeightRenaming(original_key, renamed_key)) + new_converter = WeightRenaming(original_key, renamed_key) + mapping = param_name_to_load.get(renamed_key, new_converter) source_pattern = original_key # 3. Handle dtype casting @@ -884,7 +880,7 @@ def convert_and_load_state_dict_in_model( _dtype = empty_param.dtype # usually correct when initializing # 4. Handle TP sharding or device_map placement - future_or_tensor = None + future = None if device_mesh: if matched_tp_pattern := tp_plan_alt.search(renamed_key): matched_tp_pattern = tp_plan_by_group_name[matched_tp_pattern.lastgroup] @@ -902,14 +898,24 @@ def convert_and_load_state_dict_in_model( _dtype, ) - if future_or_tensor is None: + if future is None: device_match = device_map_regex.match(renamed_key) param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu") - # If disk, we need to materialize on cpu first - param_device = "cpu" if param_device == "disk" else param_device - future_or_tensor = spawn_materialize(thread_pool, tensor, param_device, _dtype) - - mapping.add_tensor(renamed_key, original_key, source_pattern, future_or_tensor) + # If disk, we skip loading the weight entirely if we only rename + if param_device == "disk": + # Simply remove from missing keys, no need to load it + if renamed_key in disk_offload_index and isinstance(mapping, WeightRenaming): + missing_keys.discard(renamed_key) + # Need to be loaded and resaved + else: + future = spawn_materialize(thread_pool, tensor, "cpu", _dtype) + else: + future = spawn_materialize(thread_pool, tensor, param_device, _dtype) + + # In the case of offloading, can still be None as we skip loading the param entirely + if future is not None: + mapping.add_tensor(renamed_key, original_key, source_pattern, future_or_tensor) + param_name_to_load[renamed_key] = mapping elif source_pattern is not None: # add all target keys as unexpected mapping = pattern_to_converter[source_pattern] for k in mapping.target_patterns: @@ -939,8 +945,8 @@ def convert_and_load_state_dict_in_model( param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu") # Offloading support if param_device == "disk": - disk_offload_index = offload_and_maybe_resave_param( - target_name, param, missing_keys, disk_offload_folder, disk_offload_index, mapping + disk_offload_index = offload_and_resave_param( + target_name, param, missing_keys, disk_offload_folder, disk_offload_index ) else: set_param_for_module( diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index fb6a8380cd32..3e9954479376 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -64,6 +64,7 @@ from transformers.testing_utils import ( TOKEN, CaptureLogger, + CPUMemoryMonitor, LoggingLevel, TemporaryHubRepo, TestCasePlus, @@ -2263,6 +2264,25 @@ def tracking_init(self, *args, **kwargs): # Reverse monkey patch threading.Thread.__init__ = original_init + def test_offloading_does_not_use_more_cpu_memory(self): + """Test that when we must have weights offloaded to the disk, loading will be performed synchronously + and sequentially, i.e. we do not use more cpu memory than available. Avoids regresion after + https://github.com/huggingface/transformers/pull/42632 and""" + from transformers import Qwen3VLForConditionalGeneration + + # Small enough, non-gated model + model_name = "Qwen/Qwen3-VL-2B-Instruct" + # This will make sure we load params on only 2GB of cpu memory, and everything else is offloaded to disk (model is + # about 4GiB on fp16) + max_memory = {"cpu": "2GIB"} + monitor = CPUMemoryMonitor() + _ = Qwen3VLForConditionalGeneration.from_pretrained( + model_name, device_map="auto", max_memory=max_memory, dtype=torch.float16 + ) + peak = monitor.get_stats().peak_rss_gib + # We use 2.1 here instead of 2 to avoid being too flaky + self.assertTrue(peak < 2.1, "The process used more than 2GiB to load the model") + @slow @require_torch From 2a9dd74f8eecef44ff9df24d79c924fbfb3af97c Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 5 Dec 2025 18:56:55 +0100 Subject: [PATCH 2/2] fix --- src/transformers/core_model_loading.py | 4 ++-- tests/utils/test_modeling_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index c69d954e77ab..1689bc8ea9c2 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -890,7 +890,7 @@ def convert_and_load_state_dict_in_model( device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone() ) shard_index = len(mapping.collected_tensors.get(original_key, [])) - future_or_tensor = spawn_tp_materialize( + future = spawn_tp_materialize( thread_pool, tensor, mapping.distributed_operation, @@ -914,7 +914,7 @@ def convert_and_load_state_dict_in_model( # In the case of offloading, can still be None as we skip loading the param entirely if future is not None: - mapping.add_tensor(renamed_key, original_key, source_pattern, future_or_tensor) + mapping.add_tensor(renamed_key, original_key, source_pattern, future) param_name_to_load[renamed_key] = mapping elif source_pattern is not None: # add all target keys as unexpected mapping = pattern_to_converter[source_pattern] diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 3e9954479376..d45beadd3880 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -2267,7 +2267,7 @@ def tracking_init(self, *args, **kwargs): def test_offloading_does_not_use_more_cpu_memory(self): """Test that when we must have weights offloaded to the disk, loading will be performed synchronously and sequentially, i.e. we do not use more cpu memory than available. Avoids regresion after - https://github.com/huggingface/transformers/pull/42632 and""" + https://github.com/huggingface/transformers/pull/42632 and https://github.com/huggingface/transformers/pull/42665""" from transformers import Qwen3VLForConditionalGeneration # Small enough, non-gated model