import re
from abc import abstractmethod
from collections import defaultdict
-from collections.abc import MutableMapping, MutableSet
+from collections.abc import Callable, MutableMapping, MutableSet
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import contextmanager
from copy import deepcopy
@@ -327,10 +327,6 @@ def add_tensor(self, target_key: str, source_key: str, source_pattern: str, futu
        self.collected_tensors[source_pattern].append(future)
        self.layer_targets[target_key].add(source_key)

-    def reset(self) -> None:
-        """Clean-up the collected tensors to make sure we don't keep references to past tensors in memory."""
-        self.collected_tensors = defaultdict(list)
-
    def rename_source_key(self, source_key: str) -> tuple[str, str | None]:
        """
        Return a tuple (renamed_key, source_pattern_producing_the_match).
@@ -375,6 +371,32 @@ def reverse_transform(self) -> WeightTransform:

        return reverse_transform

+    def materialize_tensors(self) -> dict[str, list[torch.Tensor]]:
+        """
+        Materialize all the tensors that were saved in `self.collected_tensors`. This removes them from the internal
+        attribute, so that we do not keep references to past tensors in memory across the different `self.convert`
+        calls, and returns them in a new dictionary (otherwise we would use more memory than needed during loading).
+
+        There are 3 cases here:
+        - async loading (default): the tensors are Future instances that we need to wait for
+        - sync loading: the tensors are Callables that we need to call to actually load the data from disk
+        - saving: the tensors are already torch.Tensor instances (the existing model weights)
+        """
+        collected_tensors = {}
+        for key in set(self.collected_tensors.keys()):
+            # Remove from internal attribute
+            tensors = self.collected_tensors.pop(key)
+            # Async loading
+            if isinstance(tensors[0], Future):
+                tensors = [future.result() for future in tensors]
+            # Sync loading
+            elif callable(tensors[0]):
+                tensors = [func() for func in tensors]
+            # Add them to the new dictionary
+            collected_tensors[key] = tensors
+
+        return collected_tensors
+

@dataclass(slots=True)
class WeightRenaming(WeightTransform):
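Note on the new `materialize_tensors` helper above: the three cases its docstring lists (Future, zero-arg Callable, already-loaded tensor) can be seen in isolation in the minimal sketch below. This is a toy reconstruction, not the library code itself: plain floats stand in for torch tensors and the key names are made up.

    from collections import defaultdict
    from concurrent.futures import Future, ThreadPoolExecutor

    collected = defaultdict(list)

    with ThreadPoolExecutor(max_workers=2) as pool:
        collected["layer.0.weight"].append(pool.submit(lambda: 1.0))  # async loading: a Future
        collected["layer.1.weight"].append(lambda: 2.0)               # sync loading: a lazy zero-arg Callable
        collected["layer.2.weight"].append(3.0)                       # saving path: already materialized

        materialized = {}
        for key in list(collected):
            items = collected.pop(key)  # drop the internal reference so it cannot linger in memory
            if isinstance(items[0], Future):
                items = [f.result() for f in items]
            elif callable(items[0]):
                items = [fn() for fn in items]
            materialized[key] = items

    print(materialized)
    # {'layer.0.weight': [1.0], 'layer.1.weight': [2.0], 'layer.2.weight': [3.0]}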
@@ -389,19 +411,17 @@ def convert(
        missing_keys: Optional[MutableSet[str]] = None,
        misc: Optional[MutableMapping[str, str]] = None,
    ):
-        # Collect the tensor if using threading
-        for pattern, futures in self.collected_tensors.items():
-            self.collected_tensors[pattern] = (
-                futures if isinstance(futures[0], torch.Tensor) else [future.result() for future in futures]
-            )
+        # Collect the tensors here - we build a new dictionary so that the internal attribute does not keep
+        # references to them alive for the whole loading process
+        collected_tensors = self.materialize_tensors()

        # Perform renaming op (for a simple WeightRenaming, `self.source_patterns` and `self.target_patterns` can
        # only be of length 1, and are actually the full key names - we also have only 1 single related tensor)
        target_key = self.target_patterns[0]
-        collected_tensors = {target_key: self.collected_tensors[self.source_patterns[0]]}
+        collected_tensors = {target_key: collected_tensors[self.source_patterns[0]]}

        if hf_quantizer is not None and self.quantization_operation is not None:
-            with log_to_misc(layer_name, misc, (self.collected_tensors, layer_name), self.quantization_operation):
+            with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), self.quantization_operation):
                collected_tensors = self.quantization_operation.convert(
                    collected_tensors,
                    source_patterns=self.source_patterns,
@@ -437,15 +457,12 @@ def convert(
        missing_keys: Optional[MutableSet[str]] = None,
        misc: Optional[MutableMapping[str, str]] = None,
    ):
-        # Collect all tensors if using threading
-        for pattern, futures in self.collected_tensors.items():
-            self.collected_tensors[pattern] = (
-                futures if isinstance(futures[0], torch.Tensor) else [future.result() for future in futures]
-            )
+        # Collect the tensors here - we build a new dictionary so that the internal attribute does not keep
+        # references to them alive for the whole loading process
+        collected_tensors = self.materialize_tensors()

-        collected_tensors = self.collected_tensors
        for op in self.operations:
-            with log_to_misc(layer_name, misc, (collected_tensors, layer_name), op):
+            with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), op):
                collected_tensors = op.convert(
                    collected_tensors,
                    source_patterns=self.source_patterns,
@@ -472,7 +489,7 @@ def convert(
                pass

        if hf_quantizer is not None and self.quantization_operation is not None:
-            with log_to_misc(layer_name, misc, (collected_tensors, layer_name), self.quantization_operation):
+            with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), self.quantization_operation):
                collected_tensors = self.quantization_operation.convert(
                    collected_tensors,
                    source_patterns=self.source_patterns,
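For readers less familiar with this part of the loader, the loop above applies each operation to the output of the previous one, over a dict mapping target keys to lists of tensors. Below is a self-contained sketch of that chaining pattern, with made-up toy operations rather than the real ConversionOps classes, and nested lists standing in for tensors.

    class Transpose:
        """Toy op: transpose every 2-D nested list."""

        def convert(self, tensors: dict[str, list]) -> dict[str, list]:
            return {key: [[list(col) for col in zip(*t)] for t in values] for key, values in tensors.items()}

    class Scale:
        """Toy op: multiply every element by a constant factor."""

        def __init__(self, factor):
            self.factor = factor

        def convert(self, tensors: dict[str, list]) -> dict[str, list]:
            return {
                key: [[[x * self.factor for x in row] for row in t] for t in values]
                for key, values in tensors.items()
            }

    collected = {"proj.weight": [[[1, 2], [3, 4]]]}
    for op in (Transpose(), Scale(10)):
        collected = op.convert(collected)  # each op consumes the previous op's output
    print(collected)  # {'proj.weight': [[[10, 30], [20, 40]]]}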
@@ -501,27 +518,36 @@ def _materialize_copy(tensor: torch.Tensor, device=None, dtype=None) -> torch.Te

def spawn_materialize(
    thread_pool: ThreadPoolExecutor | None, tensor: torch.Tensor, device=None, dtype=None
-) -> Future | torch.Tensor:
-    """Materialize a tensor from file asynchronously if `thread_pool` is provided, or immediately otherwise."""
+) -> Future | Callable:
+    """Materialize a tensor from file asynchronously if `thread_pool` is provided, or return a Callable that will
+    load the tensor synchronously when called."""
+
+    def _job():
+        return _materialize_copy(tensor, device, dtype)
+
    if thread_pool is not None:
-        return thread_pool.submit(_materialize_copy, tensor, device, dtype)
+        return thread_pool.submit(_job)
    else:
-        return _materialize_copy(tensor, device, dtype)
+        # Return the Callable here, not the Tensor itself, so that we actually delay loading and avoid
+        # saturating CPU memory during conversion
+        return _job


def spawn_tp_materialize(
    thread_pool: ThreadPoolExecutor | None, tensor: torch.Tensor, sharding_method, tensor_idx, dtype=None
-) -> Future | torch.Tensor:
+) -> Future | Callable:
    """Materialize and shard a tensor (according to the TP-plan) from file asynchronously if `thread_pool` is provided, or
-    immediately otherwise."""
+    return a Callable that will load the tensor synchronously when called."""

    def _job():
        return sharding_method.shard_tensor(tensor, param_casting_dtype=dtype, tensor_idx=tensor_idx)[0]

    if thread_pool is not None:
        return thread_pool.submit(_job)
    else:
-        return _job()
+        # Return the Callable here, not the Tensor itself, so that we actually delay loading and avoid
+        # saturating CPU memory during conversion
+        return _job


def dot_natural_key(s: str):
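The change to `spawn_materialize` / `spawn_tp_materialize` above swaps eager loading for deferred loading in the sequential path: without a thread pool, the helpers now hand back a zero-argument closure, so the expensive copy only happens once `materialize_tensors` later calls it. A minimal standalone sketch of that idea follows; the `rows` data and the `spawn_load` name are illustrative, not the library API.

    from concurrent.futures import Future, ThreadPoolExecutor
    from typing import Callable

    rows = {"a": list(range(5)), "b": list(range(5, 10))}  # toy stand-in for tensors sitting on disk

    def spawn_load(pool: "ThreadPoolExecutor | None", key: str) -> "Future | Callable":
        def _job():
            # The expensive copy happens only when _job runs, not when it is created
            return list(rows[key])

        return pool.submit(_job) if pool is not None else _job

    # Async path: work starts immediately on a worker thread, the caller waits on the Future
    with ThreadPoolExecutor(max_workers=2) as pool:
        handle = spawn_load(pool, "a")
        print(handle.result())  # [0, 1, 2, 3, 4]

    # Sync path: nothing is loaded until the closure is actually called
    handle = spawn_load(None, "b")
    print(handle())  # [5, 6, 7, 8, 9]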
@@ -557,10 +583,10 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) ->

        op_name = _format_op_name(op)
        if isinstance(extras, tuple) and len(extras) == 2:
-            values, target_keys = extras
+            length, target_keys = extras
            descriptor = f"{op_name} " if op_name else ""
            misc[first_target_key] = (
-                f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values)}"
+                f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {length}"
            )
        elif isinstance(extras, str):
            suffix = f" via {op_name}" if op_name else ""
@@ -796,11 +822,12 @@ def convert_and_load_state_dict_in_model(
    mismatch_keys = set()
    unexpected_keys = set()

-    # We use threading by default, if not explicitly deactivated via env variable
-    if not is_env_variable_true("HF_DEACTIVATE_ASYNC_LOAD"):
-        thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
-    else:
+    # We use threading by default, unless it is explicitly deactivated via env variable. If we have to offload to
+    # disk, we cannot use it either: we are under memory constraints, so we need to load sequentially to control memory
+    if is_env_variable_true("HF_DEACTIVATE_ASYNC_LOAD") or "disk" in device_map.values():
        thread_pool = None
+    else:
+        thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)

    renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
    converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
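Small usage note on the branch above, based only on what this diff shows: async loading is now skipped either when the `HF_DEACTIVATE_ASYNC_LOAD` env variable is truthy or when the `device_map` offloads anything to `"disk"`. A rough sketch of that decision; the truthiness check is a stand-in for `is_env_variable_true`, and the `device_map` layout is made up.

    import os

    # Either of the two conditions below forces sequential loading
    os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
    device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": "disk"}

    env_disables_async = os.environ.get("HF_DEACTIVATE_ASYNC_LOAD", "").lower() in ("1", "true", "yes")
    use_thread_pool = not (env_disables_async or "disk" in device_map.values())
    print(use_thread_pool)  # False: both conditions trigger here, either one alone would be enough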
@@ -928,8 +955,8 @@ def convert_and_load_state_dict_in_model(
                    hf_quantizer,
                )

-                # Cleanup the tensors that were gathered internally in the mapping
-                mapping.reset()
+                # Drop the reference to the materialized tensors before the next iteration
+                del realized_value

        except SkipLayer:
            continue