diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 65331ba8322c..1215f3ddf96b 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -5074,14 +5074,17 @@ def create_accelerator_and_postprocess(self):
         self.is_tp_enabled = False
         if getattr(self.model, "tp_size", None) is not None and self.model.tp_size > 1:
             self.is_tp_enabled = True
-            if self.args.parallelism_config is not None:
-                if is_accelerate_available("1.10.1"):
-                    if self.args.parallelism_config is not None:
-                        from accelerate import ParallelismConfig
-
-                        args["parallelism_config"] = ParallelismConfig(tp_size=self.model.tp_size)
+            if is_accelerate_available("1.10.1"):
+                if self.args.parallelism_config is not None:
+                    # Update tp_size in user-provided config instead of overwriting it
+                    self.args.parallelism_config.tp_size = self.model.tp_size
                 else:
-                    raise ValueError("Requires accelerate>1.10.1 to use Tensor Parallelism.")
+                    # Only create new config if user didn't provide one
+                    from accelerate import ParallelismConfig
+
+                    args["parallelism_config"] = ParallelismConfig(tp_size=self.model.tp_size)
+            else:
+                raise ValueError("Requires accelerate>1.10.1 to use Tensor Parallelism.")
         if is_accelerate_available("1.2.0"):
             # it we don't have the correct version, we will rely on env var instead that were set in TrainingArguments
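
To make the behavioral change easier to see, here is a minimal sketch of the patched branch in isolation: a user-supplied `ParallelismConfig` is kept and only its `tp_size` is synced to the model, and a fresh config is built only when none was provided. The `resolve_parallelism_config` helper and its arguments are illustrative stand-ins for the `Trainer` internals shown in the diff, not part of the library.

```python
from accelerate import ParallelismConfig  # requires accelerate >= 1.10.1


def resolve_parallelism_config(user_config, model_tp_size):
    """Sketch of the patched logic: prefer the user's config, only sync tp_size."""
    if user_config is not None:
        # Update tp_size in the user-provided config instead of overwriting it,
        # so any other parallelism settings the user chose are preserved.
        user_config.tp_size = model_tp_size
        return user_config
    # Only create a new config if the user didn't provide one.
    return ParallelismConfig(tp_size=model_tp_size)


# Hypothetical check: the user's object survives, with tp_size refreshed.
cfg = ParallelismConfig(tp_size=1)
assert resolve_parallelism_config(cfg, model_tp_size=4) is cfg
assert cfg.tp_size == 4
```

This mirrors the intent of the diff: before the change, a `parallelism_config` passed through `TrainingArguments` was replaced by a new `ParallelismConfig(tp_size=...)`; after it, the user's object is mutated in place and only the tensor-parallel size is reconciled with the model.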