Skip to content

Commit fda2d73

Browse files
efazalSunMarcgithub-actions[bot]
authored
feat(trainer): Just-in-time (JIT) asynchronous checkpointing using SIGTERM signals (#41723)
* Just-in-time (JIT) asynchronous checkpointing using SIGTERM signals and cuda streams. * Fix failing ci tests * Update JIT checkpoint code to remove CUDA streams and async checkpointing. Introduce sentinel file to identify incomplete checkpoints. Update trainer arg doc and tests. * Fix sentinel file save path to checkpoint folder, update checkpoint related envs with HF_ prefix. * Refactor JIT checkpoint logic: rename methods and variables for clarity, improve SIGTERM handling, and update related tests. * Remove support for environment variable overrides in `TrainingArguments` and update related documentation. * Apply style fixes --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent b083169 commit fda2d73

File tree

4 files changed

+570
-1
lines changed

4 files changed

+570
-1
lines changed

src/transformers/trainer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,16 @@ def __init__(
642642
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
643643
)
644644
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
645+
646+
# Add JIT checkpoint callback if enabled
647+
if self.args.enable_jit_checkpoint:
648+
from .trainer_jit_checkpoint import JITCheckpointCallback
649+
650+
jit_callback = JITCheckpointCallback()
651+
default_callbacks = default_callbacks + [jit_callback]
652+
# Set trainer reference for JIT callback after initialization
653+
jit_callback.set_trainer(self)
654+
645655
callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
646656
self.callback_handler = CallbackHandler(
647657
callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import os
2+
import signal
3+
import threading
4+
from typing import Optional
5+
6+
from .trainer_callback import TrainerCallback
7+
from .trainer_utils import PREFIX_CHECKPOINT_DIR
8+
from .utils import logging
9+
10+
11+
logger = logging.get_logger(__name__)
12+
13+
14+
class CheckpointManager:
    """
    Handles Just-In-Time (JIT) checkpointing when a SIGTERM signal is received.

    On SIGTERM, a timer waits `kill_wait` seconds (to distinguish a graceful
    SIGTERM from an imminent SIGKILL) before setting `is_checkpoint_requested`,
    which the `JITCheckpointCallback` hooks poll from the training loop.
    """

    def __init__(self, trainer, kill_wait: int = 3):
        """
        Initialize the CheckpointManager for Just-In-Time checkpoint handling.

        Args:
            trainer: The Trainer instance that will be used to save checkpoints when SIGTERM is received.
            kill_wait (`int`, *optional*, defaults to 3): Grace period to distinguish between SIGTERM and SIGKILL.
        """
        self.trainer = trainer
        self.is_checkpoint_requested = False
        self._original_sigterm_handler = None
        self.kill_wait = kill_wait
        # Guards against spawning one grace-period timer per signal when
        # SIGTERM is delivered repeatedly within the kill_wait window.
        self._timer_started = False

    def setup_signal_handler(self):
        # NOTE(review): signal.signal must run in the main thread of the main
        # interpreter, otherwise it raises ValueError — callers are expected to
        # construct the Trainer in the main thread.
        self._original_sigterm_handler = signal.signal(signal.SIGTERM, self._sigterm_handler)
        logger.info("JIT checkpoint signal handler registered for SIGTERM")

    def _sigterm_handler(self, signum, frame):
        # Ignore repeated SIGTERMs: either a checkpoint was already requested
        # or the grace-period timer is already counting down.
        if self.is_checkpoint_requested or self._timer_started:
            return

        logger.info(f"SIGTERM received, will request JIT checkpoint after {self.kill_wait}s")
        self._timer_started = True
        timer = threading.Timer(self.kill_wait, self._enable_checkpoint)
        # Daemon timer so a pending grace period never blocks interpreter shutdown.
        timer.daemon = True
        timer.start()

    def _enable_checkpoint(self):
        # Timer callback: flip the flag that the callback hooks poll.
        logger.info("Kill wait period elapsed, requesting checkpoint")
        self.is_checkpoint_requested = True

    def execute_jit_checkpoint(self):
        """
        Save a checkpoint immediately, marking it with a sentinel file while incomplete.

        The sentinel file is written before the save and removed after it
        succeeds, so a surviving sentinel identifies an interrupted checkpoint.

        Raises:
            Exception: re-raises any error from the underlying checkpoint save.
        """
        try:
            # Set checkpoint flag to False to avoid multiple checkpoints getting triggered by other callbacks
            self.is_checkpoint_requested = False

            logger.info("Starting JIT checkpointing...")
            current_step = self.trainer.state.global_step
            logger.info(f"Saving JIT checkpoint at step {current_step}")

            output_dir = self.trainer._get_output_dir(trial=None)
            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{current_step}"
            checkpoint_path = os.path.join(output_dir, checkpoint_folder)

            # Create checkpoint directory
            os.makedirs(checkpoint_path, exist_ok=True)

            # Create a sentinel file to indicate checkpointing is in progress
            # (reuse checkpoint_path instead of re-joining output_dir + folder).
            sentinel_file = os.path.join(checkpoint_path, "checkpoint-is-incomplete.txt")
            with open(sentinel_file, "w") as f:
                f.write(f"Checkpoint started at step {current_step} and in progress...")
            logger.info(f"Created checkpoint progress sentinel marker file: {sentinel_file}")

            # Invoke the trainer's checkpoint method directly
            self.trainer._save_checkpoint(self.trainer.model, trial=None)

            # Remove sentinel file upon successful checkpointing
            if os.path.exists(sentinel_file):
                os.remove(sentinel_file)
                logger.info("Sentinel marker file removed")

            logger.info("Immediate JIT checkpoint completed successfully")

        except Exception as e:
            logger.error(f"Failed to save JIT checkpoint: {e}")
            raise
78+
79+
80+
class JITCheckpointCallback(TrainerCallback):
    """
    Callback for Just-In-Time checkpointing on SIGTERM signals.

    When SIGTERM is received, the checkpoint manager sets `is_checkpoint_requested=True`.
    The callbacks detect this flag and set `control.should_training_stop=True`, which signals
    the Trainer's training loop to exit gracefully after saving the checkpoint.
    """

    def __init__(self):
        # Wired up later via `set_trainer`, once the Trainer exists.
        self.trainer = None
        self.jit_manager: Optional[CheckpointManager] = None

    def set_trainer(self, trainer):
        """Attach the Trainer and, when enabled, install the SIGTERM handler."""
        self.trainer = trainer
        if not trainer.args.enable_jit_checkpoint:
            return
        self.jit_manager = CheckpointManager(trainer=trainer)
        self.jit_manager.setup_signal_handler()
        logger.info("JIT checkpointing enabled")

    def _checkpoint_pending(self):
        # True when a manager exists and a SIGTERM-triggered checkpoint was requested.
        return self.jit_manager is not None and self.jit_manager.is_checkpoint_requested

    def on_pre_optimizer_step(self, args, state, control, **kwargs):
        if self._checkpoint_pending():
            control.should_training_stop = True
            self.jit_manager.execute_jit_checkpoint()

    def on_step_begin(self, args, state, control, **kwargs):
        if self._checkpoint_pending():
            control.should_training_stop = True
            self.jit_manager.execute_jit_checkpoint()

    def on_step_end(self, args, state, control, **kwargs):
        if self._checkpoint_pending():
            # Suppress the Trainer's regular save — the JIT checkpoint replaces it.
            control.should_save = False
            control.should_training_stop = True
            self.jit_manager.execute_jit_checkpoint()

    def on_epoch_end(self, args, state, control, **kwargs):
        if self._checkpoint_pending():
            control.should_save = False
            control.should_training_stop = True
            self.jit_manager.execute_jit_checkpoint()

    def on_train_end(self, args, state, control, **kwargs):
        # Restore original SIGTERM handler
        if self.jit_manager and self.jit_manager._original_sigterm_handler is not None:
            signal.signal(signal.SIGTERM, self.jit_manager._original_sigterm_handler)
            logger.info("Restored original SIGTERM handler after training completion")

src/transformers/training_args.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,17 @@ class TrainingArguments:
340340
`save_total_limit=5` and `load_best_model_at_end`, the four last checkpoints will always be retained
341341
alongside the best model. When `save_total_limit=1` and `load_best_model_at_end`, it is possible that two
342342
checkpoints are saved: the last one and the best one (if they are different).
343+
enable_jit_checkpoint (`bool`, *optional*, defaults to `False`):
344+
Whether to enable Just-In-Time (JIT) checkpointing on SIGTERM signal. When enabled, training will
345+
checkpoint upon receiving SIGTERM, allowing for graceful termination without losing
346+
progress. This is particularly useful for shared clusters with preemptible workloads (e.g., Kueue).
347+
**Important**: You must configure your orchestrator's graceful shutdown period to allow sufficient time
348+
for checkpoint completion. For Kubernetes, set `terminationGracePeriodSeconds` in your job definition
349+
(method varies by cloud-native trainer: Kubeflow, Ray, etc.). Note: the default is only 30 seconds,
350+
which is typically insufficient. For Slurm, use `--signal=TERM@<seconds>` in your sbatch script to send
351+
SIGTERM with adequate time before the job time limit. Calculate the required grace period as: longest
352+
possible iteration time + checkpoint saving time. For example, if an iteration takes 2 minutes and
353+
checkpoint saving takes 2 minutes, set at least 4 minutes (240 seconds) of grace time.
343354
save_safetensors (`bool`, *optional*, defaults to `True`):
344355
Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts instead of
345356
default `torch.load` and `torch.save`.
@@ -929,7 +940,23 @@ class TrainingArguments:
929940
" for `save_total_limit=5` and `load_best_model_at_end=True`, the four last checkpoints will always be"
930941
" retained alongside the best model. When `save_total_limit=1` and `load_best_model_at_end=True`,"
931942
" it is possible that two checkpoints are saved: the last one and the best one (if they are different)."
932-
" Default is unlimited checkpoints"
943+
" Default is unlimited checkpoints."
944+
)
945+
},
946+
)
947+
enable_jit_checkpoint: bool = field(
948+
default=False,
949+
metadata={
950+
"help": (
951+
"Whether to enable Just-In-Time (JIT) checkpointing on SIGTERM signal. "
952+
"When enabled, training will checkpoint upon receiving SIGTERM, "
953+
"allowing for graceful termination without losing progress. "
954+
"This is particularly useful for shared clusters with preemptible workloads (Kueue). "
955+
"IMPORTANT: You must configure your orchestrator's graceful shutdown period. "
956+
"Kubernetes: set terminationGracePeriodSeconds (default 30s is insufficient!) in your job definition. "
957+
"Slurm: use --signal=USR1@<seconds> in sbatch to send SIGTERM before time limit. "
958+
"Calculate required grace period as: iteration time + checkpoint saving time. "
959+
"Example: 2min iteration + 2min checkpoint = 240 seconds minimum."
933960
)
934961
},
935962
)

0 commit comments

Comments
 (0)