46 changes: 26 additions & 20 deletions src/transformers/core_model_loading.py
@@ -644,22 +644,17 @@ def set_param_for_module(
         setattr(module_obj, param_name, param_value)
 
 
-def offload_and_maybe_resave_param(
+def offload_and_resave_param(
     target_name: str,
     param: torch.Tensor,
     missing_keys: MutableSet[str],
     disk_offload_folder: str,
     disk_offload_index: dict,
-    applied_ops: WeightConverter | WeightRenaming,
 ) -> dict:
-    """Takes care of correctly offloading `param`. If it's not already present in the `disk_offload_index`, or if any
-    WeightConverter operations have been applied, it will resave the new parameter. Otherwise, it will use the original
-    `disk_offload_index` for this given param."""
+    """Takes care of correctly offloading `param`. It will resave the new parameter, and update the index."""
     # We need to remove from missing keys
     missing_keys.discard(target_name)
-    # If not already offloaded, or if we applied any special Operation except Renaming, we need to re-save
-    if target_name not in disk_offload_index or isinstance(applied_ops, WeightConverter):
-        disk_offload_index = offload_weight(param, target_name, disk_offload_folder, disk_offload_index)
+    disk_offload_index = offload_weight(param, target_name, disk_offload_folder, disk_offload_index)
     return disk_offload_index
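
With the conditional gone, every parameter that reaches this helper is rewritten on disk and its index entry is refreshed; the "already offloaded, only renamed" case is now short-circuited earlier, in the loading loop below, before the tensor is ever materialized. A minimal sketch of the resulting index bookkeeping, using a toy stand-in for `offload_weight` (assuming it follows the accelerate-style convention of dumping the tensor to `<folder>/<name>.dat` and recording dtype/shape metadata; names and index fields here are illustrative only):

```python
import os

import torch


def toy_offload_weight(weight, weight_name, offload_folder, index):
    # Illustrative stand-in: persist the tensor and record where/how it was saved.
    os.makedirs(offload_folder, exist_ok=True)
    weight.cpu().numpy().tofile(os.path.join(offload_folder, f"{weight_name}.dat"))
    index[weight_name] = {"dtype": str(weight.dtype), "shape": list(weight.shape)}
    return index


# The helper unconditionally discards the key from `missing_keys` and resaves the tensor,
# overwriting any stale file and index entry from a previous save.
missing_keys = {"model.layers.0.mlp.up_proj.weight"}
disk_offload_index = {}
param = torch.randn(4, 4, dtype=torch.float16)

missing_keys.discard("model.layers.0.mlp.up_proj.weight")
disk_offload_index = toy_offload_weight(
    param, "model.layers.0.mlp.up_proj.weight", "/tmp/offload", disk_offload_index
)
```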


@@ -856,10 +851,11 @@ def convert_and_load_state_dict_in_model(
             if source_pattern is not None:
                 new_converter = deepcopy(pattern_to_converter[source_pattern])
                 # each target key gets its own converter instance
-                mapping = param_name_to_load.setdefault(renamed_key, new_converter)
+                mapping = param_name_to_load.get(renamed_key, new_converter)
             # Otherwise, only potential renaming
             else:
-                mapping = param_name_to_load.setdefault(renamed_key, WeightRenaming(original_key, renamed_key))
+                new_converter = WeightRenaming(original_key, renamed_key)
+                mapping = param_name_to_load.get(renamed_key, new_converter)
                 source_pattern = original_key
 
             # 3. Handle dtype casting
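
The switch from `setdefault` to `get` is not cosmetic: `setdefault` registers the converter in `param_name_to_load` as a side effect of the lookup, even for keys that never get a tensor scheduled, while `get` leaves the dict untouched until the explicit `param_name_to_load[renamed_key] = mapping` further down, which only runs once a materialization future is actually attached. A small illustration of the difference, with placeholder values:

```python
param_name_to_load = {}

# setdefault inserts the default as a side effect of the lookup ...
mapping = param_name_to_load.setdefault("lm_head.weight", "converter-A")
assert "lm_head.weight" in param_name_to_load

param_name_to_load.clear()

# ... whereas get leaves the dict untouched until the entry is written back explicitly.
mapping = param_name_to_load.get("lm_head.weight", "converter-A")
assert "lm_head.weight" not in param_name_to_load
param_name_to_load["lm_head.weight"] = mapping  # done later, only when a tensor is actually scheduled
```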
@@ -884,7 +880,7 @@ def convert_and_load_state_dict_in_model(
                 _dtype = empty_param.dtype  # usually correct when initializing
 
             # 4. Handle TP sharding or device_map placement
-            future_or_tensor = None
+            future = None
             if device_mesh:
                 if matched_tp_pattern := tp_plan_alt.search(renamed_key):
                     matched_tp_pattern = tp_plan_by_group_name[matched_tp_pattern.lastgroup]
@@ -894,22 +890,32 @@ def convert_and_load_state_dict_in_model(
                         device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone()
                     )
                     shard_index = len(mapping.collected_tensors.get(original_key, []))
-                    future_or_tensor = spawn_tp_materialize(
+                    future = spawn_tp_materialize(
                         thread_pool,
                         tensor,
                         mapping.distributed_operation,
                         shard_index,
                         _dtype,
                     )
 
-            if future_or_tensor is None:
+            if future is None:
                 device_match = device_map_regex.match(renamed_key)
                 param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
-                # If disk, we need to materialize on cpu first
-                param_device = "cpu" if param_device == "disk" else param_device
-                future_or_tensor = spawn_materialize(thread_pool, tensor, param_device, _dtype)
-
-            mapping.add_tensor(renamed_key, original_key, source_pattern, future_or_tensor)
+                # If disk, we skip loading the weight entirely if we only rename
+                if param_device == "disk":
+                    # Simply remove from missing keys, no need to load it
+                    if renamed_key in disk_offload_index and isinstance(mapping, WeightRenaming):
+                        missing_keys.discard(renamed_key)
+                    # Need to be loaded and resaved
+                    else:
+                        future = spawn_materialize(thread_pool, tensor, "cpu", _dtype)
+                else:
+                    future = spawn_materialize(thread_pool, tensor, param_device, _dtype)
+
+            # In the case of offloading, can still be None as we skip loading the param entirely
+            if future is not None:
+                mapping.add_tensor(renamed_key, original_key, source_pattern, future)
+                param_name_to_load[renamed_key] = mapping
         elif source_pattern is not None:  # add all target keys as unexpected
             mapping = pattern_to_converter[source_pattern]
             for k in mapping.target_patterns:
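
Condensed, the new device-map branch makes a three-way decision for each parameter: load onto the mapped device, load onto cpu so the weight can be resaved to the offload folder, or skip loading entirely when the weight is already offloaded and only renamed. The sketch below restates that decision as a standalone function; the name `plan_materialization` and the boolean flag are hypothetical, and the real code works with futures and converter objects instead:

```python
from typing import Optional


def plan_materialization(
    param_device: str,
    renamed_key: str,
    is_pure_rename: bool,
    disk_offload_index: dict,
    missing_keys: set,
) -> Optional[str]:
    """Return the device to materialize on, or None when loading can be skipped."""
    if param_device != "disk":
        # Normal case: materialize directly on the target device.
        return param_device
    if renamed_key in disk_offload_index and is_pure_rename:
        # Already saved on disk and only renamed: nothing to load, just not "missing".
        missing_keys.discard(renamed_key)
        return None
    # Must be loaded (and later resaved into the offload folder) via cpu.
    return "cpu"


# A renamed-only weight that is already indexed is skipped entirely:
missing = {"model.embed_tokens.weight"}
assert plan_materialization("disk", "model.embed_tokens.weight", True,
                            {"model.embed_tokens.weight": {}}, missing) is None
assert "model.embed_tokens.weight" not in missing
```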
@@ -939,8 +945,8 @@ def convert_and_load_state_dict_in_model(
             param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
             # Offloading support
             if param_device == "disk":
-                disk_offload_index = offload_and_maybe_resave_param(
-                    target_name, param, missing_keys, disk_offload_folder, disk_offload_index, mapping
+                disk_offload_index = offload_and_resave_param(
+                    target_name, param, missing_keys, disk_offload_folder, disk_offload_index
                 )
             else:
                 set_param_for_module(
20 changes: 20 additions & 0 deletions tests/utils/test_modeling_utils.py
@@ -64,6 +64,7 @@
 from transformers.testing_utils import (
     TOKEN,
     CaptureLogger,
+    CPUMemoryMonitor,
     LoggingLevel,
     TemporaryHubRepo,
     TestCasePlus,
@@ -2263,6 +2264,25 @@ def tracking_init(self, *args, **kwargs):
         # Reverse monkey patch
         threading.Thread.__init__ = original_init
 
+    def test_offloading_does_not_use_more_cpu_memory(self):
+        """Test that when we must have weights offloaded to the disk, loading will be performed synchronously
+        and sequentially, i.e. we do not use more cpu memory than available. Avoids regression after
+        https://github.com/huggingface/transformers/pull/42632 and https://github.com/huggingface/transformers/pull/42665"""
+        from transformers import Qwen3VLForConditionalGeneration
+
+        # Small enough, non-gated model
+        model_name = "Qwen/Qwen3-VL-2B-Instruct"
+        # This will make sure we load params on only 2GiB of cpu memory, and everything else is offloaded to disk
+        # (the model is about 4GiB in fp16)
+        max_memory = {"cpu": "2GIB"}
+        monitor = CPUMemoryMonitor()
+        _ = Qwen3VLForConditionalGeneration.from_pretrained(
+            model_name, device_map="auto", max_memory=max_memory, dtype=torch.float16
+        )
+        peak = monitor.get_stats().peak_rss_gib
+        # We use 2.1 here instead of 2 to avoid being too flaky
+        self.assertTrue(peak < 2.1, "The process used more than 2GiB to load the model")
+
 
 @slow
 @require_torch
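
The test relies on `CPUMemoryMonitor` from `transformers.testing_utils`; only its constructor and `get_stats().peak_rss_gib` are exercised above. A rough sketch of what such a peak-RSS monitor can look like, built on `psutil` (this is an assumption about the implementation, not the actual helper):

```python
import threading
import time
from dataclasses import dataclass

import psutil


@dataclass
class MemoryStats:
    peak_rss_gib: float


class ToyCPUMemoryMonitor:
    """Polls the current process RSS in a background thread and keeps the maximum."""

    def __init__(self, interval: float = 0.05):
        self._process = psutil.Process()
        self._interval = interval
        self._peak = 0
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._poll, daemon=True)
        self._thread.start()

    def _poll(self):
        while not self._stop.is_set():
            self._peak = max(self._peak, self._process.memory_info().rss)
            time.sleep(self._interval)

    def get_stats(self) -> MemoryStats:
        self._stop.set()
        self._thread.join()
        return MemoryStats(peak_rss_gib=self._peak / 1024**3)
```

A monitor of this kind is started before `from_pretrained` and queried afterwards, exactly as the test above does with the real helper.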