From 9ada964ce174baaae42f7006068e805040ae822a Mon Sep 17 00:00:00 2001
From: arrdel
Date: Fri, 5 Dec 2025 21:01:54 -0500
Subject: [PATCH] Fix parallelism_config being overwritten in TP-only training

Fixes #42661

When using TP with a user-provided parallelism_config, the Trainer was
incorrectly overwriting the entire config object with a new
ParallelismConfig(tp_size=model.tp_size), discarding all user-provided
settings (dp_size, pp_size, cp_backend, etc.).

Changes:
- If the user provides a parallelism_config, update only its tp_size attribute
- If no config is provided, create a new ParallelismConfig with tp_size
- Remove the redundant nested condition check
- Fix the logic flow to check the accelerate version first

This ensures user-provided parallelism configurations are preserved
during TP-only training.
---
 src/transformers/trainer.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 65331ba8322c..1215f3ddf96b 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -5074,14 +5074,17 @@ def create_accelerator_and_postprocess(self):
         self.is_tp_enabled = False
         if getattr(self.model, "tp_size", None) is not None and self.model.tp_size > 1:
             self.is_tp_enabled = True
-            if self.args.parallelism_config is not None:
-                if is_accelerate_available("1.10.1"):
-                    if self.args.parallelism_config is not None:
-                        from accelerate import ParallelismConfig
-
-                        args["parallelism_config"] = ParallelismConfig(tp_size=self.model.tp_size)
+            if is_accelerate_available("1.10.1"):
+                if self.args.parallelism_config is not None:
+                    # Update tp_size in user-provided config instead of overwriting it
+                    self.args.parallelism_config.tp_size = self.model.tp_size
                 else:
-                    raise ValueError("Requires accelerate>1.10.1 to use Tensor Parallelism.")
+                    # Only create new config if user didn't provide one
+                    from accelerate import ParallelismConfig
+
+                    args["parallelism_config"] = ParallelismConfig(tp_size=self.model.tp_size)
+            else:
+                raise ValueError("Requires accelerate>1.10.1 to use Tensor Parallelism.")
         if is_accelerate_available("1.2.0"):
             # it we don't have the correct version, we will rely on env var instead that were set in TrainingArguments
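
For illustration, a minimal sketch of the scenario this patch fixes, using only what the
diff itself confirms (the accelerate ParallelismConfig import, its tp_size argument, and
in-place attribute assignment). The dp_shard_size field is an assumed, illustrative name;
the exact constructor arguments depend on the installed accelerate version.

    from accelerate import ParallelismConfig

    # A user-provided config carrying more than the TP degree
    # (dp_shard_size is an assumed/illustrative field name).
    user_config = ParallelismConfig(tp_size=1, dp_shard_size=4)

    # Old behaviour (simplified): the Trainer built a fresh config from the
    # model's tp_size, losing everything else the user had set.
    old_result = ParallelismConfig(tp_size=8)

    # New behaviour (simplified): only the TP degree is updated on the
    # existing object, so dp_shard_size and any other user settings survive.
    user_config.tp_size = 8
    new_result = user_config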