Skip to content

Commit d94f46c

Browse files
committed
Move checkpointing out of signal handlers
Signal handlers are a precarious environment, because they can be called at any time, for example in the middle of another function. So performing IO in there can lead to difficult-to-track-down errors. This change moves the heavy lifting of requeuing into the normal life cycle of lightning loops.
1 parent 79ffe50 commit d94f46c

File tree

4 files changed

+168
-172
lines changed

4 files changed

+168
-172
lines changed

src/lightning/pytorch/trainer/connectors/signal_connector.py

Lines changed: 125 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,161 +1,182 @@
11
import logging
22
import os
3-
import re
43
import signal
4+
import subprocess
55
import threading
6-
from subprocess import call
76
from types import FrameType
87
from typing import Any, Callable, Union
98

10-
import torch
11-
import torch.distributed as dist
12-
139
import lightning.pytorch as pl
1410
from lightning.fabric.plugins.environments import SLURMEnvironment
1511
from lightning.fabric.utilities.imports import _IS_WINDOWS
16-
from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_info
12+
from lightning.pytorch.callbacks import Callback
13+
from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_debug, rank_zero_info, rank_zero_warn
1714

1815
# copied from signal.pyi
1916
_SIGNUM = Union[int, signal.Signals]
20-
_HANDLER = Union[Callable[[_SIGNUM, FrameType], Any], int, signal.Handlers, None]
17+
_HANDLER = Union[Callable[[_SIGNUM, FrameType | None], Any], int, signal.Handlers, None]
2118

2219
log = logging.getLogger(__name__)
2320

2421

2522
class _HandlersCompose:
26-
def __init__(self, signal_handlers: Union[list[_HANDLER], _HANDLER]) -> None:
27-
if not isinstance(signal_handlers, list):
28-
signal_handlers = [signal_handlers]
23+
def __init__(self, signal_handlers: list[_HANDLER]) -> None:
2924
self.signal_handlers = signal_handlers
3025

31-
def __call__(self, signum: _SIGNUM, frame: FrameType) -> None:
26+
def __call__(self, signum: _SIGNUM, frame: FrameType | None) -> None:
3227
for signal_handler in self.signal_handlers:
28+
if signal_handler is signal.SIG_DFL or signal_handler is signal.SIG_IGN:
29+
# If the handler is SIG_IGN, we skip it. Since there is no way for us to
30+
# trigger the default signal handler, we ignore that one too
31+
continue
3332
if isinstance(signal_handler, int):
3433
signal_handler = signal.getsignal(signal_handler)
3534
if callable(signal_handler):
3635
signal_handler(signum, frame)
3736

3837

38+
class _SignalFlag:
39+
"""Becomes true when called as a signal handler."""
40+
41+
def __init__(self) -> None:
42+
self.state = False
43+
44+
def __call__(self, signum: _SIGNUM, _: FrameType | None) -> None:
45+
self.set()
46+
47+
def set(self) -> None:
48+
self.state = True
49+
50+
def check_and_reset(self) -> bool:
51+
"""Check the flag and reset it to false."""
52+
state = self.state
53+
self.state = False
54+
return state
55+
56+
57+
class _SignalHandlerCallback(Callback):
58+
def __init__(self, connector: "_SignalConnector") -> None:
59+
self.connector = connector
60+
61+
# Register same method for all callback methods
62+
on_methods = [f for f in dir(Callback) if f.startswith("on_") and callable(getattr(Callback, f))]
63+
for f in ["setup", "teardown"] + on_methods:
64+
setattr(self, f, self.notify_connector)
65+
66+
def notify_connector(self, *args: Any, **kwargs: Any) -> None:
67+
self.connector._process_signals()
68+
69+
3970
class _SignalConnector:
71+
"""Listen for process signals to, for example, requeue a job on a SLURM cluster.
72+
73+
The connector only stores the reception of signals in flags and then processes them at the next possible opportunity
74+
in the current loop. This minimizes the amount of code running in signal handlers, because file IO in signal
75+
handlers can crash the process. This also guarantees that we are not checkpointing in the middle of a backward pass
76+
or similar.
77+
78+
"""
79+
4080
def __init__(self, trainer: "pl.Trainer") -> None:
41-
self.received_sigterm = False
4281
self.trainer = trainer
82+
83+
# This flag is checked by the trainer and loops to exit gracefully
84+
self.received_sigterm = False
85+
86+
self.sigterm_flag = _SignalFlag()
87+
self.requeue_flag = _SignalFlag()
88+
4389
self._original_handlers: dict[_SIGNUM, _HANDLER] = {}
4490

91+
def register_callback(self) -> None:
92+
callback = _SignalHandlerCallback(self)
93+
self.trainer.callbacks = self.trainer.callbacks + [callback]
94+
4595
def register_signal_handlers(self) -> None:
46-
self.received_sigterm = False
47-
self._original_handlers = self._get_current_signal_handlers()
96+
if _IS_WINDOWS:
97+
# Windows seems to have signal incompatibilities
98+
rank_zero_info("Not registering signal handlers on Windows OS")
99+
return
48100

49-
sigusr_handlers: list[_HANDLER] = []
50-
sigterm_handlers: list[_HANDLER] = [self._sigterm_notifier_fn]
101+
if threading.current_thread() is not threading.main_thread():
102+
# Skip signal registration to allow training in non-main-threads
103+
rank_zero_debug("Not registering signal handlers outside of the main thread")
104+
return
105+
106+
self._register_signal_handler(signal.SIGTERM, self.sigterm_flag)
51107

52108
environment = self.trainer._accelerator_connector.cluster_environment
53109
if isinstance(environment, SLURMEnvironment) and environment.auto_requeue:
54-
log.info("SLURM auto-requeueing enabled. Setting signal handlers.")
55-
sigusr_handlers.append(self._slurm_sigusr_handler_fn)
56-
sigterm_handlers.append(self._sigterm_handler_fn)
57-
58-
# Windows seems to have signal incompatibilities
59-
if not _IS_WINDOWS:
60-
sigusr = environment.requeue_signal if isinstance(environment, SLURMEnvironment) else signal.SIGUSR1
61-
assert sigusr is not None
62-
if sigusr_handlers and not self._has_already_handler(sigusr):
63-
self._register_signal(sigusr, _HandlersCompose(sigusr_handlers))
64-
65-
# we have our own handler, but include existing ones too
66-
if self._has_already_handler(signal.SIGTERM):
67-
sigterm_handlers.append(signal.getsignal(signal.SIGTERM))
68-
self._register_signal(signal.SIGTERM, _HandlersCompose(sigterm_handlers))
69-
70-
def _slurm_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
71-
rank_zero_info(f"Handling auto-requeue signal: {signum}")
72-
73-
# save logger to make sure we get all the metrics
110+
requeue_signal = environment.requeue_signal
111+
if requeue_signal is None:
112+
rank_zero_warn("Requested SLURM auto-requeueing, but signal is disabled. Could not set it up.")
113+
else:
114+
rank_zero_info(f"SLURM auto-requeueing enabled. Setting signal handlers for {requeue_signal.name}.")
115+
self._register_signal_handler(requeue_signal, self.requeue_flag)
116+
117+
def _process_signals(self) -> None:
118+
if self.requeue_flag.check_and_reset():
119+
rank_zero_info("Handling auto-requeue signal")
120+
self._slurm_requeue()
121+
122+
if self.sigterm_flag.check_and_reset():
123+
log.info(rank_prefixed_message("Received SIGTERM. Stopping.", self.trainer.local_rank))
124+
# Forward signal to subprocesses the first time it is received
125+
if not self.received_sigterm:
126+
launcher = self.trainer.strategy.launcher
127+
if launcher is not None:
128+
launcher.kill(signal.SIGTERM)
129+
self.received_sigterm = True
130+
131+
def _slurm_requeue(self) -> None:
132+
# Save logger to make sure we get all the metrics
74133
for logger in self.trainer.loggers:
75134
logger.finalize("finished")
76135

77136
hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(self.trainer.default_root_dir)
78137
self.trainer.save_checkpoint(hpc_save_path)
79138

80139
if self.trainer.is_global_zero:
81-
# find job id
82-
array_job_id = os.getenv("SLURM_ARRAY_JOB_ID")
83-
if array_job_id is not None:
84-
array_task_id = os.environ["SLURM_ARRAY_TASK_ID"]
85-
job_id = f"{array_job_id}_{array_task_id}"
86-
else:
87-
job_id = os.environ["SLURM_JOB_ID"]
88-
89-
assert re.match("[0-9_-]+", job_id)
140+
job_id = self._slurm_job_id()
90141
cmd = ["scontrol", "requeue", job_id]
91142

92-
# requeue job
93-
log.info(f"requeing job {job_id}...")
143+
# Requeue job
144+
log.info(f"Requeueing job {job_id}...")
94145
try:
95-
result = call(cmd)
146+
result = subprocess.run(cmd, capture_output=True, text=True)
96147
except FileNotFoundError:
97-
# This can occur if a subprocess call to `scontrol` is run outside a shell context
98-
# Re-attempt call (now with shell context). If any error is raised, propagate to user.
99-
# When running a shell command, it should be passed as a single string.
100-
result = call(" ".join(cmd), shell=True)
101-
102-
# print result text
103-
if result == 0:
104-
log.info(f"Requeued SLURM job: {job_id}")
105-
else:
106-
log.warning(f"Requeuing SLURM job {job_id} failed with error code {result}")
107-
108-
def _sigterm_notifier_fn(self, signum: _SIGNUM, _: FrameType) -> None:
109-
log.info(rank_prefixed_message(f"Received SIGTERM: {signum}", self.trainer.local_rank))
110-
if not self.received_sigterm:
111-
launcher = self.trainer.strategy.launcher
112-
if launcher is not None:
113-
launcher.kill(signum)
148+
# This can occur if a subprocess call to `scontrol` is run outside a shell context.
149+
# Try enlisting the help of the shell to resolve the `scontrol` binary.
150+
result = subprocess.run(" ".join(cmd), capture_output=True, text=True, shell=True)
114151

115-
# New broadcast logic
116-
if dist.is_available() and dist.is_initialized() and self.trainer.world_size > 1:
117-
sigterm_tensor = torch.tensor([1], device=self.trainer.strategy.root_device)
118-
dist.broadcast(sigterm_tensor, src=0)
119-
120-
self.received_sigterm = True
152+
# Print result text
153+
if result.returncode == 0:
154+
log.info(f"Requeued SLURM job {job_id}")
155+
else:
156+
log.warning(
157+
f"Requeueing SLURM job {job_id} failed with error code {result.returncode}: {result.stderr}"
158+
)
121159

122-
def _sigterm_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
123-
log.info(f"Bypassing SIGTERM: {signum}")
160+
self.trainer.should_stop = True
124161

125162
def teardown(self) -> None:
126-
"""Restores the signals that were previously configured before :class:`_SignalConnector` replaced them."""
127-
for signum, handler in self._original_handlers.items():
128-
if handler is not None:
129-
self._register_signal(signum, handler)
130-
self._original_handlers = {}
131-
132-
@staticmethod
133-
def _get_current_signal_handlers() -> dict[_SIGNUM, _HANDLER]:
134-
"""Collects the currently assigned signal handlers."""
135-
valid_signals = _SignalConnector._valid_signals()
136-
if not _IS_WINDOWS:
137-
# SIGKILL and SIGSTOP are not allowed to be modified by the user
138-
valid_signals -= {signal.SIGKILL, signal.SIGSTOP}
139-
return {signum: signal.getsignal(signum) for signum in valid_signals}
140-
141-
@staticmethod
142-
def _valid_signals() -> set[signal.Signals]:
143-
"""Returns all valid signals supported on the current platform."""
144-
return signal.valid_signals()
145-
146-
@staticmethod
147-
def _has_already_handler(signum: _SIGNUM) -> bool:
148-
return signal.getsignal(signum) not in (None, signal.SIG_DFL)
149-
150-
@staticmethod
151-
def _register_signal(signum: _SIGNUM, handlers: _HANDLER) -> None:
163+
"""Restores the signals that :class:`_SignalConnector` overwrote."""
152164
if threading.current_thread() is threading.main_thread():
153-
signal.signal(signum, handlers) # type: ignore[arg-type]
165+
for signum, handler in self._original_handlers.items():
166+
signal.signal(signum, handler)
167+
self._original_handlers = {}
154168

155-
def __getstate__(self) -> dict:
156-
state = self.__dict__.copy()
157-
state["_original_handlers"] = {}
158-
return state
169+
def _register_signal_handler(self, signum: _SIGNUM, handler: _HANDLER) -> None:
170+
orig_handler = signal.getsignal(signum)
171+
self._original_handlers[signum] = orig_handler
172+
signal.signal(signum, _HandlersCompose([orig_handler, handler]))
173+
174+
def _slurm_job_id(self) -> str:
175+
array_job_id = os.environ.get("SLURM_ARRAY_JOB_ID")
176+
if array_job_id is not None:
177+
array_task_id = os.environ["SLURM_ARRAY_TASK_ID"]
178+
return f"{array_job_id}_{array_task_id}"
179+
return os.environ["SLURM_JOB_ID"]
159180

160181

161182
def _get_sigkill_signal() -> _SIGNUM:

src/lightning/pytorch/trainer/trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,9 @@ def __init__(
453453
max_time,
454454
)
455455

456+
# Handle signals as part of the train loop
457+
self._signal_connector.register_callback()
458+
456459
# init data flags
457460
self.check_val_every_n_epoch: Optional[int]
458461
self._data_connector.on_trainer_init(

tests/tests_pytorch/conftest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from lightning.fabric.utilities.distributed import _destroy_dist_connection, _distributed_is_initialized
3333
from lightning.fabric.utilities.imports import _IS_WINDOWS
3434
from lightning.pytorch.accelerators import XLAAccelerator
35-
from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector
3635
from tests_pytorch import _PATH_DATASETS
3736

3837

@@ -113,7 +112,7 @@ def restore_signal_handlers():
113112
This is a safety net for tests that don't run Trainer's teardown.
114113
115114
"""
116-
valid_signals = _SignalConnector._valid_signals()
115+
valid_signals = signal.valid_signals()
117116
if not _IS_WINDOWS:
118117
# SIGKILL and SIGSTOP are not allowed to be modified by the user
119118
valid_signals -= {signal.SIGKILL, signal.SIGSTOP}

0 commit comments

Comments
 (0)