From d94f46cdfd3897c6e519ed102bf51d83d7930f2e Mon Sep 17 00:00:00 2001
From: Marten Lienen
Date: Tue, 2 Dec 2025 19:40:47 +0100
Subject: [PATCH 1/2] Move checkpointing out of signal handlers

Signal handlers are a precarious environment because they can be called at any
time, for example in the middle of another function, so performing IO in them
can lead to errors that are difficult to track down. This change moves the
heavy lifting of requeuing into the normal life cycle of the Lightning loops.
---
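Notes:

The handlers installed by this patch only flip a flag; the actual checkpointing,
requeueing, and shutdown happen later, from the loop's callback hooks. A minimal
standalone sketch of that pattern in plain Python (illustrative only, not the
Lightning implementation):

    import signal
    import time

    class SignalFlag:
        """Simplified stand-in for the _SignalFlag class added by this patch."""

        def __init__(self) -> None:
            self.state = False

        def __call__(self, signum, frame) -> None:
            # Only flip a flag inside the handler; do no IO here.
            self.state = True

        def check_and_reset(self) -> bool:
            state, self.state = self.state, False
            return state

    sigterm_flag = SignalFlag()
    signal.signal(signal.SIGTERM, sigterm_flag)

    for step in range(1_000_000):
        time.sleep(0.01)  # stand-in for one training step
        if sigterm_flag.check_and_reset():
            # Safe point in the loop: checkpoint/requeue/stop outside the handler.
            print("SIGTERM received, stopping gracefully")
            break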
 .../trainer/connectors/signal_connector.py    | 229 ++++++++++--------
 src/lightning/pytorch/trainer/trainer.py      |   3 +
 tests/tests_pytorch/conftest.py               |   3 +-
 .../connectors/test_signal_connector.py       | 105 +++-----
 4 files changed, 168 insertions(+), 172 deletions(-)

diff --git a/src/lightning/pytorch/trainer/connectors/signal_connector.py b/src/lightning/pytorch/trainer/connectors/signal_connector.py
index ece7e902c5f5f..3a570e11e6b44 100644
--- a/src/lightning/pytorch/trainer/connectors/signal_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/signal_connector.py
@@ -1,76 +1,135 @@
 import logging
 import os
-import re
 import signal
+import subprocess
 import threading
-from subprocess import call
 from types import FrameType
 from typing import Any, Callable, Union
 
-import torch
-import torch.distributed as dist
-
 import lightning.pytorch as pl
 from lightning.fabric.plugins.environments import SLURMEnvironment
 from lightning.fabric.utilities.imports import _IS_WINDOWS
-from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_info
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_debug, rank_zero_info, rank_zero_warn
 
 # copied from signal.pyi
 _SIGNUM = Union[int, signal.Signals]
-_HANDLER = Union[Callable[[_SIGNUM, FrameType], Any], int, signal.Handlers, None]
+_HANDLER = Union[Callable[[_SIGNUM, FrameType | None], Any], int, signal.Handlers, None]
 
 log = logging.getLogger(__name__)
 
 
 class _HandlersCompose:
-    def __init__(self, signal_handlers: Union[list[_HANDLER], _HANDLER]) -> None:
-        if not isinstance(signal_handlers, list):
-            signal_handlers = [signal_handlers]
+    def __init__(self, signal_handlers: list[_HANDLER]) -> None:
         self.signal_handlers = signal_handlers
 
-    def __call__(self, signum: _SIGNUM, frame: FrameType) -> None:
+    def __call__(self, signum: _SIGNUM, frame: FrameType | None) -> None:
         for signal_handler in self.signal_handlers:
+            if signal_handler is signal.SIG_DFL or signal_handler is signal.SIG_IGN:
+                # If the handler is SIG_IGN, we skip it. Since there is no way for us
+                # to trigger the default signal handler, we skip SIG_DFL as well.
+                continue
             if isinstance(signal_handler, int):
                 signal_handler = signal.getsignal(signal_handler)
             if callable(signal_handler):
                 signal_handler(signum, frame)
 
 
+class _SignalFlag:
+    """Becomes true when called as a signal handler."""
+
+    def __init__(self) -> None:
+        self.state = False
+
+    def __call__(self, signum: _SIGNUM, _: FrameType | None) -> None:
+        self.set()
+
+    def set(self) -> None:
+        self.state = True
+
+    def check_and_reset(self) -> bool:
+        """Check the flag and reset it to false."""
+        state = self.state
+        self.state = False
+        return state
+
+
+class _SignalHandlerCallback(Callback):
+    def __init__(self, connector: "_SignalConnector") -> None:
+        self.connector = connector
+
+        # Register the same method for all callback hooks
+        on_methods = [f for f in dir(Callback) if f.startswith("on_") and callable(getattr(Callback, f))]
+        for f in ["setup", "teardown"] + on_methods:
+            setattr(self, f, self.notify_connector)
+
+    def notify_connector(self, *args: Any, **kwargs: Any) -> None:
+        self.connector._process_signals()
+
+
 class _SignalConnector:
+    """Listen for process signals, for example to requeue a job on a SLURM cluster.
+
+    The connector only records received signals in flags and processes them at the next opportunity in the current
+    loop. This minimizes the amount of code running in signal handlers, because file IO in signal handlers can crash
+    the process. It also guarantees that we are not checkpointing in the middle of a backward pass or similar.
+
+    """
+
     def __init__(self, trainer: "pl.Trainer") -> None:
-        self.received_sigterm = False
         self.trainer = trainer
+
+        # This flag is checked by the trainer and loops to exit gracefully
+        self.received_sigterm = False
+
+        self.sigterm_flag = _SignalFlag()
+        self.requeue_flag = _SignalFlag()
+
         self._original_handlers: dict[_SIGNUM, _HANDLER] = {}
 
+    def register_callback(self) -> None:
+        callback = _SignalHandlerCallback(self)
+        self.trainer.callbacks = self.trainer.callbacks + [callback]
+
     def register_signal_handlers(self) -> None:
-        self.received_sigterm = False
-        self._original_handlers = self._get_current_signal_handlers()
+        if _IS_WINDOWS:
+            # Windows seems to have signal incompatibilities
+            rank_zero_info("Not registering signal handlers on Windows OS")
+            return
 
-        sigusr_handlers: list[_HANDLER] = []
-        sigterm_handlers: list[_HANDLER] = [self._sigterm_notifier_fn]
+        if threading.current_thread() is not threading.main_thread():
+            # Skip signal registration to allow training in non-main threads
+            rank_zero_debug("Not registering signal handlers outside of the main thread")
+            return
+
+        self._register_signal_handler(signal.SIGTERM, self.sigterm_flag)
 
         environment = self.trainer._accelerator_connector.cluster_environment
         if isinstance(environment, SLURMEnvironment) and environment.auto_requeue:
-            log.info("SLURM auto-requeueing enabled. Setting signal handlers.")
Setting signal handlers.") - sigusr_handlers.append(self._slurm_sigusr_handler_fn) - sigterm_handlers.append(self._sigterm_handler_fn) - - # Windows seems to have signal incompatibilities - if not _IS_WINDOWS: - sigusr = environment.requeue_signal if isinstance(environment, SLURMEnvironment) else signal.SIGUSR1 - assert sigusr is not None - if sigusr_handlers and not self._has_already_handler(sigusr): - self._register_signal(sigusr, _HandlersCompose(sigusr_handlers)) - - # we have our own handler, but include existing ones too - if self._has_already_handler(signal.SIGTERM): - sigterm_handlers.append(signal.getsignal(signal.SIGTERM)) - self._register_signal(signal.SIGTERM, _HandlersCompose(sigterm_handlers)) - - def _slurm_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None: - rank_zero_info(f"Handling auto-requeue signal: {signum}") - - # save logger to make sure we get all the metrics + requeue_signal = environment.requeue_signal + if requeue_signal is None: + rank_zero_warn("Requested SLURM auto-requeueing, but signal is disabled. Could not set it up.") + else: + rank_zero_info(f"SLURM auto-requeueing enabled. Setting signal handlers for {requeue_signal.name}.") + self._register_signal_handler(requeue_signal, self.requeue_flag) + + def _process_signals(self) -> None: + if self.requeue_flag.check_and_reset(): + rank_zero_info("Handling auto-requeue signal") + self._slurm_requeue() + + if self.sigterm_flag.check_and_reset(): + log.info(rank_prefixed_message("Received SIGTERM. Stopping.", self.trainer.local_rank)) + # Forward signal to subprocesses the first time it is received + if not self.received_sigterm: + launcher = self.trainer.strategy.launcher + if launcher is not None: + launcher.kill(signal.SIGTERM) + self.received_sigterm = True + + def _slurm_requeue(self) -> None: + # Save logger to make sure we get all the metrics for logger in self.trainer.loggers: logger.finalize("finished") @@ -78,84 +137,46 @@ def _slurm_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None: self.trainer.save_checkpoint(hpc_save_path) if self.trainer.is_global_zero: - # find job id - array_job_id = os.getenv("SLURM_ARRAY_JOB_ID") - if array_job_id is not None: - array_task_id = os.environ["SLURM_ARRAY_TASK_ID"] - job_id = f"{array_job_id}_{array_task_id}" - else: - job_id = os.environ["SLURM_JOB_ID"] - - assert re.match("[0-9_-]+", job_id) + job_id = self._slurm_job_id() cmd = ["scontrol", "requeue", job_id] - # requeue job - log.info(f"requeing job {job_id}...") + # Requeue job + log.info(f"Requeueing job {job_id}...") try: - result = call(cmd) + result = subprocess.run(cmd, capture_output=True, text=True) except FileNotFoundError: - # This can occur if a subprocess call to `scontrol` is run outside a shell context - # Re-attempt call (now with shell context). If any error is raised, propagate to user. - # When running a shell command, it should be passed as a single string. - result = call(" ".join(cmd), shell=True) - - # print result text - if result == 0: - log.info(f"Requeued SLURM job: {job_id}") - else: - log.warning(f"Requeuing SLURM job {job_id} failed with error code {result}") - - def _sigterm_notifier_fn(self, signum: _SIGNUM, _: FrameType) -> None: - log.info(rank_prefixed_message(f"Received SIGTERM: {signum}", self.trainer.local_rank)) - if not self.received_sigterm: - launcher = self.trainer.strategy.launcher - if launcher is not None: - launcher.kill(signum) + # This can occur if a subprocess call to `scontrol` is run outside a shell context. 
+                # Try enlisting the help of the shell to resolve the `scontrol` binary.
+                result = subprocess.run(" ".join(cmd), capture_output=True, text=True, shell=True)
 
-        # New broadcast logic
-        if dist.is_available() and dist.is_initialized() and self.trainer.world_size > 1:
-            sigterm_tensor = torch.tensor([1], device=self.trainer.strategy.root_device)
-            dist.broadcast(sigterm_tensor, src=0)
-
-        self.received_sigterm = True
+            # Print result text
+            if result.returncode == 0:
+                log.info(f"Requeued SLURM job {job_id}")
+            else:
+                log.warning(
+                    f"Requeueing SLURM job {job_id} failed with error code {result.returncode}: {result.stderr}"
+                )
 
-    def _sigterm_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
-        log.info(f"Bypassing SIGTERM: {signum}")
+        self.trainer.should_stop = True
 
     def teardown(self) -> None:
-        """Restores the signals that were previously configured before :class:`_SignalConnector` replaced them."""
-        for signum, handler in self._original_handlers.items():
-            if handler is not None:
-                self._register_signal(signum, handler)
-        self._original_handlers = {}
-
-    @staticmethod
-    def _get_current_signal_handlers() -> dict[_SIGNUM, _HANDLER]:
-        """Collects the currently assigned signal handlers."""
-        valid_signals = _SignalConnector._valid_signals()
-        if not _IS_WINDOWS:
-            # SIGKILL and SIGSTOP are not allowed to be modified by the user
-            valid_signals -= {signal.SIGKILL, signal.SIGSTOP}
-        return {signum: signal.getsignal(signum) for signum in valid_signals}
-
-    @staticmethod
-    def _valid_signals() -> set[signal.Signals]:
-        """Returns all valid signals supported on the current platform."""
-        return signal.valid_signals()
-
-    @staticmethod
-    def _has_already_handler(signum: _SIGNUM) -> bool:
-        return signal.getsignal(signum) not in (None, signal.SIG_DFL)
-
-    @staticmethod
-    def _register_signal(signum: _SIGNUM, handlers: _HANDLER) -> None:
+        """Restores the signals that :class:`_SignalConnector` overwrote."""
         if threading.current_thread() is threading.main_thread():
-            signal.signal(signum, handlers)  # type: ignore[arg-type]
+            for signum, handler in self._original_handlers.items():
+                signal.signal(signum, handler)
+        self._original_handlers = {}
 
-    def __getstate__(self) -> dict:
-        state = self.__dict__.copy()
-        state["_original_handlers"] = {}
-        return state
+    def _register_signal_handler(self, signum: _SIGNUM, handler: _HANDLER) -> None:
+        orig_handler = signal.getsignal(signum)
+        self._original_handlers[signum] = orig_handler
+        signal.signal(signum, _HandlersCompose([orig_handler, handler]))
+
+    def _slurm_job_id(self) -> str:
+        array_job_id = os.environ.get("SLURM_ARRAY_JOB_ID")
+        if array_job_id is not None:
+            array_task_id = os.environ["SLURM_ARRAY_TASK_ID"]
+            return f"{array_job_id}_{array_task_id}"
+        return os.environ["SLURM_JOB_ID"]
 
 
 def _get_sigkill_signal() -> _SIGNUM:
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 6f947160ba9cb..949080ea27607 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -453,6 +453,9 @@ def __init__(
             max_time,
         )
 
+        # Handle signals as part of the train loop
+        self._signal_connector.register_callback()
+
         # init data flags
         self.check_val_every_n_epoch: Optional[int]
         self._data_connector.on_trainer_init(
diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py
index 878298c6bfd94..befd8120a232d 100644
--- a/tests/tests_pytorch/conftest.py
+++ b/tests/tests_pytorch/conftest.py
@@ -32,7 +32,6 @@
 from lightning.fabric.utilities.distributed import _destroy_dist_connection, _distributed_is_initialized
 from lightning.fabric.utilities.imports import _IS_WINDOWS
 from lightning.pytorch.accelerators import XLAAccelerator
-from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector
 from tests_pytorch import _PATH_DATASETS
 
 
@@ -113,7 +112,7 @@ def restore_signal_handlers():
 
     This is a safety net for tests that don't run Trainer's teardown.
     """
-    valid_signals = _SignalConnector._valid_signals()
+    valid_signals = signal.valid_signals()
     if not _IS_WINDOWS:
         # SIGKILL and SIGSTOP are not allowed to be modified by the user
         valid_signals -= {signal.SIGKILL, signal.SIGSTOP}
diff --git a/tests/tests_pytorch/trainer/connectors/test_signal_connector.py b/tests/tests_pytorch/trainer/connectors/test_signal_connector.py
index 83c5c2bb7e02b..0f3b39d7d5762 100644
--- a/tests/tests_pytorch/trainer/connectors/test_signal_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_signal_connector.py
@@ -23,7 +23,7 @@
 from lightning.fabric.utilities.imports import _IS_WINDOWS
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector
+from lightning.pytorch.trainer.connectors.signal_connector import _HandlersCompose, _SignalConnector, _SignalFlag
 from lightning.pytorch.utilities.exceptions import SIGTERMException
 from tests_pytorch.helpers.runif import RunIf
 
@@ -73,112 +73,85 @@ def training_step(self, batch, batch_idx):
 @RunIf(skip_windows=True)
 @pytest.mark.parametrize("auto_requeue", [True, False])
 @pytest.mark.parametrize("requeue_signal", [signal.SIGUSR1, signal.SIGUSR2, signal.SIGHUP] if not _IS_WINDOWS else [])
-def test_auto_requeue_custom_signal_flag(auto_requeue, requeue_signal):
+def test_auto_requeue_signal_handlers(auto_requeue, requeue_signal):
     trainer = Trainer(plugins=[SLURMEnvironment(auto_requeue=auto_requeue, requeue_signal=requeue_signal)])
     connector = _SignalConnector(trainer)
     connector.register_signal_handlers()
 
-    if auto_requeue:
-        sigterm_handlers = signal.getsignal(signal.SIGTERM).signal_handlers
-        assert len(sigterm_handlers) == 2
-        assert sigterm_handlers[1].__qualname__ == "_SignalConnector._sigterm_handler_fn"
+    sigterm_handler = signal.getsignal(signal.SIGTERM)
+    assert isinstance(sigterm_handler, _HandlersCompose)
+    assert len(sigterm_handler.signal_handlers) == 2
+    assert sigterm_handler.signal_handlers[0] is signal.SIG_DFL
+    assert isinstance(sigterm_handler.signal_handlers[1], _SignalFlag)
 
-        sigusr_handlers = signal.getsignal(requeue_signal).signal_handlers
-        assert len(sigusr_handlers) == 1
-        assert sigusr_handlers[0].__qualname__ == "_SignalConnector._slurm_sigusr_handler_fn"
+    if auto_requeue:
+        sigusr_handler = signal.getsignal(requeue_signal)
+        assert isinstance(sigusr_handler, _HandlersCompose)
+        assert len(sigusr_handler.signal_handlers) == 2
+        assert sigusr_handler.signal_handlers[0] is signal.SIG_DFL
+        assert isinstance(sigusr_handler.signal_handlers[1], _SignalFlag)
     else:
-        sigterm_handlers = signal.getsignal(signal.SIGTERM).signal_handlers
-        assert len(sigterm_handlers) == 1
-        assert sigterm_handlers[0].__qualname__ == "_SignalConnector._sigterm_notifier_fn"
-
         assert signal.getsignal(requeue_signal) is signal.SIG_DFL
 
     connector.teardown()
 
 
 @RunIf(skip_windows=True)
-@mock.patch("lightning.pytorch.trainer.connectors.signal_connector.call")
-@mock.patch("lightning.pytorch.trainer.Trainer.save_checkpoint", mock.MagicMock())
+@mock.patch("subprocess.run", return_value=Mock(returncode=0))
+@mock.patch("lightning.pytorch.trainer.Trainer.save_checkpoint")
 @mock.patch.dict(os.environ, {"SLURM_JOB_ID": "12345"})
-def test_auto_requeue_job(call_mock):
-    call_mock.return_value = 0
+def test_auto_requeue_job(ckpt_mock, run_mock):
     trainer = Trainer(plugins=[SLURMEnvironment()])
     connector = _SignalConnector(trainer)
-    connector._slurm_sigusr_handler_fn(None, None)
-    call_mock.assert_called_once_with(["scontrol", "requeue", "12345"])
+    connector.requeue_flag.set()
+    connector._process_signals()
+
+    ckpt_mock.assert_called_once()
+    run_mock.assert_called_once()
+    assert run_mock.call_args[0][0] == ["scontrol", "requeue", "12345"]
 
 
 @RunIf(skip_windows=True)
-@mock.patch("lightning.pytorch.trainer.connectors.signal_connector.call")
-@mock.patch("lightning.pytorch.trainer.Trainer.save_checkpoint", mock.MagicMock())
+@mock.patch("subprocess.run", return_value=Mock(returncode=0))
+@mock.patch("lightning.pytorch.trainer.Trainer.save_checkpoint")
 @mock.patch.dict(os.environ, {"SLURM_JOB_ID": "12346", "SLURM_ARRAY_JOB_ID": "12345", "SLURM_ARRAY_TASK_ID": "2"})
-def test_auto_requeue_array_job(call_mock):
-    call_mock.return_value = 0
+def test_auto_requeue_array_job(ckpt_mock, run_mock):
     trainer = Trainer(plugins=[SLURMEnvironment()])
     connector = _SignalConnector(trainer)
-    connector._slurm_sigusr_handler_fn(None, None)
-    call_mock.assert_called_once_with(["scontrol", "requeue", "12345_2"])
+    connector.requeue_flag.set()
+    connector._process_signals()
 
-
-@RunIf(skip_windows=True)
-@mock.patch("lightning.pytorch.trainer.connectors.signal_connector.call")
-@mock.patch("lightning.pytorch.trainer.Trainer.save_checkpoint", mock.MagicMock())
-@mock.patch.dict(os.environ, {"SLURM_JOB_ID": "invalid"})
-def test_auto_requeue_invalid_job_id(call_mock):
-    call_mock.return_value = 0
-    trainer = Trainer(plugins=[SLURMEnvironment()])
-    connector = _SignalConnector(trainer)
-    with pytest.raises(AssertionError):
-        connector._slurm_sigusr_handler_fn(None, None)
+    ckpt_mock.assert_called_once()
+    run_mock.assert_called_once()
+    assert run_mock.call_args[0][0] == ["scontrol", "requeue", "12345_2"]
 
 
 def _registering_signals():
     trainer = Trainer()
     trainer._signal_connector.register_signal_handlers()
+    trainer._signal_connector.teardown()
 
 
 @RunIf(skip_windows=True)
-def test_signal_connector_in_thread():
+def test_no_signal_handling_in_non_main_thread():
     with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
         for future in concurrent.futures.as_completed([executor.submit(_registering_signals)]):
             assert future.exception() is None
 
 
-def signal_handler():
-    pass
-
-
-class SignalHandlers:
-    def signal_handler(self):
-        pass
-
-
-@pytest.mark.parametrize(
-    ("handler", "expected_return"),
-    [
-        (None, False),
-        (signal.Handlers.SIG_IGN, True),
-        (signal.Handlers.SIG_DFL, False),
-        (signal_handler, True),
-        (SignalHandlers().signal_handler, True),
-    ],
-)
-def test_has_already_handler(handler, expected_return):
-    """Test that the SignalConnector detects whether a signal handler is already attached."""
-    with mock.patch("lightning.pytorch.trainer.connectors.signal_connector.signal.getsignal", return_value=handler):
-        assert _SignalConnector._has_already_handler(signal.SIGTERM) is expected_return
-
-
-def test_sigterm_notifier_fn():
+def test_sigterm_sets_flag_and_kills_subprocesses():
     trainer = Mock()
     launcher = Mock()
     trainer.strategy.launcher = launcher
     connector = _SignalConnector(trainer)
 
     assert not connector.received_sigterm
-    connector._sigterm_notifier_fn(signal.SIGTERM, Mock())
-    launcher.kill.assert_called_once_with(15)
+    connector.sigterm_flag.set()
+    connector._process_signals()
+    launcher.kill.assert_called_once_with(signal.SIGTERM)
     assert connector.received_sigterm
+
     launcher.reset_mock()
-    connector._sigterm_notifier_fn(signal.SIGTERM, Mock())
+    connector.sigterm_flag.set()
+    connector._process_signals()
     launcher.kill.assert_not_called()

From f0c2c47c1862d29e6b07ca99c47178b6a6d1d5b2 Mon Sep 17 00:00:00 2001
From: Marten Lienen
Date: Fri, 5 Dec 2025 17:01:13 +0100
Subject: [PATCH 2/2] Remove unused type-ignore comment

---
 src/lightning/pytorch/_graveyard/tpu.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/_graveyard/tpu.py b/src/lightning/pytorch/_graveyard/tpu.py
index 34008e3ee556d..37fd7376ec958 100644
--- a/src/lightning/pytorch/_graveyard/tpu.py
+++ b/src/lightning/pytorch/_graveyard/tpu.py
@@ -122,4 +122,4 @@ def _patch_classes() -> None:
 _patch_sys_modules()
 _patch_classes()
 
-SingleTPUStrategy.register_strategies(pl.strategies.StrategyRegistry)  # type: ignore[has-type]
+SingleTPUStrategy.register_strategies(pl.strategies.StrategyRegistry)