Commit 1d86d00

[loading] Correctly load params during offloading & careful memory considerations (#42632)
* do not load everything in advance
* fix
* fix
* fix
* fix
* fix memory leaks during conversion
* oupsi
* fix device_map
* add doc
* fix
* doc
* make it a method
* doc
* first shot at test
* fix test
* fix
* revert test: cpu mem too hard to track correctly
* fix
1 parent fccb049 commit 1d86d00
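
For context, the scenario this change targets is a load where part of the model is offloaded to disk. A minimal sketch of such a call, assuming a hypothetical checkpoint id and offload folder (neither is part of this commit):

from transformers import AutoModelForCausalLM

# With "auto" placement, layers that do not fit on the available devices are offloaded to disk.
# After this commit, a device_map containing "disk" disables the async loading thread pool, and
# parameters are materialized lazily, one conversion at a time, to keep CPU memory bounded.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-large-model",  # hypothetical checkpoint id
    device_map="auto",
    offload_folder="./offload",   # destination for weights offloaded to disk
)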

2 files changed: 71 additions & 35 deletions

src/transformers/core_model_loading.py

Lines changed: 62 additions & 35 deletions
@@ -20,7 +20,7 @@
 import re
 from abc import abstractmethod
 from collections import defaultdict
-from collections.abc import MutableMapping, MutableSet
+from collections.abc import Callable, MutableMapping, MutableSet
 from concurrent.futures import Future, ThreadPoolExecutor
 from contextlib import contextmanager
 from copy import deepcopy
@@ -327,10 +327,6 @@ def add_tensor(self, target_key: str, source_key: str, source_pattern: str, futu
         self.collected_tensors[source_pattern].append(future)
         self.layer_targets[target_key].add(source_key)

-    def reset(self) -> None:
-        """Clean-up the collected tensors to make sure we don't keep references to past tensors in memory."""
-        self.collected_tensors = defaultdict(list)
-
     def rename_source_key(self, source_key: str) -> tuple[str, str | None]:
         """
         Return a tuple (renamed_key, source_pattern_producing_the_match).
@@ -375,6 +371,32 @@ def reverse_transform(self) -> WeightTransform:

         return reverse_transform

+    def materialize_tensors(self) -> dict[str, list[torch.Tensor]]:
+        """
+        Materialize all the tensors that were saved in `self.collected_tensors`. This function removes them from the
+        internal attribute to avoid keeping them in memory during the different `self.convert` operations, and return
+        a new dictionary (otherwise we use more memory than needed during loading).
+
+        We basically have 3 cases here:
+            - async loading (default): the tensors are Future instances that we need to wait for
+            - sync loading: the tensors are Callable, we need to call the Callable to actually load them from disk
+            - saving: the tensors are already torch.Tensor instances (the existing model weights)
+        """
+        collected_tensors = {}
+        for key in set(self.collected_tensors.keys()):
+            # Remove from internal attribute
+            tensors = self.collected_tensors.pop(key)
+            # Async loading
+            if isinstance(tensors[0], Future):
+                tensors = [future.result() for future in tensors]
+            # Sync loading
+            elif callable(tensors[0]):
+                tensors = [func() for func in tensors]
+            # Add them to the new dictionary
+            collected_tensors[key] = tensors
+
+        return collected_tensors
+

 @dataclass(slots=True)
 class WeightRenaming(WeightTransform):
@@ -389,19 +411,17 @@ def convert(
         missing_keys: Optional[MutableSet[str]] = None,
         misc: Optional[MutableMapping[str, str]] = None,
     ):
-        # Collect the tensor if using threading
-        for pattern, futures in self.collected_tensors.items():
-            self.collected_tensors[pattern] = (
-                futures if isinstance(futures[0], torch.Tensor) else [future.result() for future in futures]
-            )
+        # Collect the tensors here - we use a new dictionary to avoid keeping them in memory in the internal
+        # attribute during the whole process
+        collected_tensors = self.materialize_tensors()

         # Perform renaming op (for a simple WeightRenaming, `self.source_patterns` and `self.target_patterns` can
         # only be of length 1, and are actually the full key names - we also have only 1 single related tensor)
         target_key = self.target_patterns[0]
-        collected_tensors = {target_key: self.collected_tensors[self.source_patterns[0]]}
+        collected_tensors = {target_key: collected_tensors[self.source_patterns[0]]}

         if hf_quantizer is not None and self.quantization_operation is not None:
-            with log_to_misc(layer_name, misc, (self.collected_tensors, layer_name), self.quantization_operation):
+            with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), self.quantization_operation):
                 collected_tensors = self.quantization_operation.convert(
                     collected_tensors,
                     source_patterns=self.source_patterns,
@@ -437,15 +457,12 @@ def convert(
         missing_keys: Optional[MutableSet[str]] = None,
         misc: Optional[MutableMapping[str, str]] = None,
     ):
-        # Collect all tensors if using threading
-        for pattern, futures in self.collected_tensors.items():
-            self.collected_tensors[pattern] = (
-                futures if isinstance(futures[0], torch.Tensor) else [future.result() for future in futures]
-            )
+        # Collect the tensors here - we use a new dictionary to avoid keeping them in memory in the internal
+        # attribute during the whole process
+        collected_tensors = self.materialize_tensors()

-        collected_tensors = self.collected_tensors
         for op in self.operations:
-            with log_to_misc(layer_name, misc, (collected_tensors, layer_name), op):
+            with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), op):
                 collected_tensors = op.convert(
                     collected_tensors,
                     source_patterns=self.source_patterns,
@@ -472,7 +489,7 @@ def convert(
                     pass

         if hf_quantizer is not None and self.quantization_operation is not None:
-            with log_to_misc(layer_name, misc, (collected_tensors, layer_name), self.quantization_operation):
+            with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), self.quantization_operation):
                 collected_tensors = self.quantization_operation.convert(
                     collected_tensors,
                     source_patterns=self.source_patterns,
@@ -501,27 +518,36 @@ def _materialize_copy(tensor: torch.Tensor, device=None, dtype=None) -> torch.Te

 def spawn_materialize(
     thread_pool: ThreadPoolExecutor | None, tensor: torch.Tensor, device=None, dtype=None
-) -> Future | torch.Tensor:
-    """Materialize a tensor from file asynchronously if `thread_pool` is provided, or immediately otherwise."""
+) -> Future | Callable:
+    """Materialize a tensor from file asynchronously if `thread_pool` is provided, or return a Callable that will
+    load the tensor synchronously when called."""
+
+    def _job():
+        return _materialize_copy(tensor, device, dtype)
+
     if thread_pool is not None:
-        return thread_pool.submit(_materialize_copy, tensor, device, dtype)
+        return thread_pool.submit(_job)
     else:
-        return _materialize_copy(tensor, device, dtype)
+        # Return the Callable here, not the Tensor itself, so we actually delay loading to avoid saturating cpu
+        # memory during Conversion
+        return _job


 def spawn_tp_materialize(
     thread_pool: ThreadPoolExecutor | None, tensor: torch.Tensor, sharding_method, tensor_idx, dtype=None
-) -> Future | torch.Tensor:
+) -> Future | Callable:
     """Materialize and shard a tensor (according to the TP-plan) from file asynchronously if `thread_pool` is provided, or
-    immediately otherwise."""
+    return a Callable that will load the tensor synchronously when called."""

     def _job():
         return sharding_method.shard_tensor(tensor, param_casting_dtype=dtype, tensor_idx=tensor_idx)[0]

     if thread_pool is not None:
         return thread_pool.submit(_job)
     else:
-        return _job()
+        # Return the Callable here, not the Tensor itself, so we actually delay loading to avoid saturating cpu
+        # memory during Conversion
+        return _job


 def dot_natural_key(s: str):
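
Both helpers now follow the same deferral pattern: build a closure, submit it when a thread pool exists, otherwise hand the closure back untouched so nothing is read from disk until `materialize_tensors` runs. A stripped-down sketch of that pattern, with illustrative names rather than the library's actual helpers:

from concurrent.futures import Future, ThreadPoolExecutor


def spawn(thread_pool: ThreadPoolExecutor | None, load_fn):
    # Async path: schedule the read now and hand back a Future.
    if thread_pool is not None:
        return thread_pool.submit(load_fn)
    # Sync path: hand back the callable itself so the read is deferred until it is invoked.
    return load_fn


def materialize(items: list):
    # Mirrors the three cases handled by materialize_tensors above:
    # Futures are awaited, callables are invoked, plain tensors pass through.
    if isinstance(items[0], Future):
        return [f.result() for f in items]
    if callable(items[0]):
        return [fn() for fn in items]
    return items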
@@ -557,10 +583,10 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) ->

     op_name = _format_op_name(op)
     if isinstance(extras, tuple) and len(extras) == 2:
-        values, target_keys = extras
+        length, target_keys = extras
         descriptor = f"{op_name} " if op_name else ""
         misc[first_target_key] = (
-            f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values)}"
+            f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {length}"
         )
     elif isinstance(extras, str):
         suffix = f" via {op_name}" if op_name else ""
@@ -796,11 +822,12 @@ def convert_and_load_state_dict_in_model(
     mismatch_keys = set()
     unexpected_keys = set()

-    # We use threading by default, if not explicitly deactivated via env variable
-    if not is_env_variable_true("HF_DEACTIVATE_ASYNC_LOAD"):
-        thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
-    else:
+    # We use threading by default, if not explicitly deactivated via env variable. If we have to offload,
+    # we cannot use it either to control the memory as we are under memory constraints, so we need to be sequential
+    if is_env_variable_true("HF_DEACTIVATE_ASYNC_LOAD") or "disk" in device_map.values():
         thread_pool = None
+    else:
+        thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)

     renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
     converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
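
Sequential loading can also be forced explicitly through the environment variable checked above, regardless of the device_map. A minimal sketch (the checkpoint id is a placeholder):

import os

# Must be set before the model is loaded; the loader reads it when deciding
# whether to create the thread pool.
os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("some-org/some-large-model")  # placeholder id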
@@ -928,8 +955,8 @@ def convert_and_load_state_dict_in_model(
                     hf_quantizer,
                 )

-                # Cleanup the tensors that were gathered internally in the mapping
-                mapping.reset()
+                # Cleanup all the tensors that were gathered before next iteration
+                del realized_value

             except SkipLayer:
                 continue

src/transformers/integrations/accelerate.py

Lines changed: 9 additions & 0 deletions
@@ -392,6 +392,15 @@ def _get_device_map(
         )
     else:
         inferred_max_memory = get_max_memory(max_memory)
+
+        # If the user does not provide `max_memory`, accelerate sets the WHOLE cpu available memory as available.
+        # This is unwanted, as we don't want to set extremely tight bound and pressure for cpu if we are memory-constrained,
+        # especially if the model uses WeightConverter (because there will be some uncontrollable cpu memory spikes during
+        # the conversions before we resave the weights). In those cases, it's better to offload to disk a bit more
+        # if we were in-between, as otherwise we blow-up cpu memory
+        if max_memory is None:
+            inferred_max_memory["cpu"] *= 0.90
+
     if hf_quantizer is not None:
         inferred_max_memory = hf_quantizer.adjust_max_memory(inferred_max_memory)

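
The 0.90 factor above only applies when `max_memory` is left as None, so callers who want a precise CPU budget can pass one explicitly and skip the automatic head-room. A sketch with arbitrary sizes and a placeholder checkpoint id:

from transformers import AutoModelForCausalLM

# Explicit budgets bypass the automatic 10% CPU head-room reduction: up to 10 GiB on GPU 0,
# up to 24 GiB on CPU, everything else offloaded to disk.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-large-model",  # hypothetical checkpoint id
    device_map="auto",
    max_memory={0: "10GiB", "cpu": "24GiB"},
    offload_folder="./offload",
)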