diff --git a/src/lightning/pytorch/_graveyard/tpu.py b/src/lightning/pytorch/_graveyard/tpu.py
index 34008e3ee556d..37fd7376ec958 100644
--- a/src/lightning/pytorch/_graveyard/tpu.py
+++ b/src/lightning/pytorch/_graveyard/tpu.py
@@ -122,4 +122,4 @@ def _patch_classes() -> None:
 _patch_sys_modules()
 _patch_classes()
 
-SingleTPUStrategy.register_strategies(pl.strategies.StrategyRegistry)  # type: ignore[has-type]
+SingleTPUStrategy.register_strategies(pl.strategies.StrategyRegistry)
diff --git a/src/lightning/pytorch/trainer/connectors/signal_connector.py b/src/lightning/pytorch/trainer/connectors/signal_connector.py
index ece7e902c5f5f..3a570e11e6b44 100644
--- a/src/lightning/pytorch/trainer/connectors/signal_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/signal_connector.py
@@ -1,76 +1,135 @@
 import logging
 import os
-import re
 import signal
+import subprocess
 import threading
-from subprocess import call
 from types import FrameType
 from typing import Any, Callable, Union
 
-import torch
-import torch.distributed as dist
-
 import lightning.pytorch as pl
 from lightning.fabric.plugins.environments import SLURMEnvironment
 from lightning.fabric.utilities.imports import _IS_WINDOWS
-from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_info
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_debug, rank_zero_info, rank_zero_warn
 
 # copied from signal.pyi
 _SIGNUM = Union[int, signal.Signals]
-_HANDLER = Union[Callable[[_SIGNUM, FrameType], Any], int, signal.Handlers, None]
+_HANDLER = Union[Callable[[_SIGNUM, FrameType | None], Any], int, signal.Handlers, None]
 
 log = logging.getLogger(__name__)
 
 
 class _HandlersCompose:
-    def __init__(self, signal_handlers: Union[list[_HANDLER], _HANDLER]) -> None:
-        if not isinstance(signal_handlers, list):
-            signal_handlers = [signal_handlers]
+    def __init__(self, signal_handlers: list[_HANDLER]) -> None:
         self.signal_handlers = signal_handlers
 
-    def __call__(self, signum: _SIGNUM, frame: FrameType) -> None:
+    def __call__(self, signum: _SIGNUM, frame: FrameType | None) -> None:
         for signal_handler in self.signal_handlers:
+            if signal_handler is signal.SIG_DFL or signal_handler is signal.SIG_IGN:
+                # If the handler is SIG_IGN, we skip it. Since there is no way for us to
+                # trigger the default signal handler (SIG_DFL), we skip that one too
+                continue
             if isinstance(signal_handler, int):
                 signal_handler = signal.getsignal(signal_handler)
             if callable(signal_handler):
                 signal_handler(signum, frame)
 
 
+class _SignalFlag:
+    """Becomes true when called as a signal handler."""
+
+    def __init__(self) -> None:
+        self.state = False
+
+    def __call__(self, signum: _SIGNUM, _: FrameType | None) -> None:
+        self.set()
+
+    def set(self) -> None:
+        self.state = True
+
+    def check_and_reset(self) -> bool:
+        """Check the flag and reset it to false."""
+        state = self.state
+        self.state = False
+        return state
+
+
+class _SignalHandlerCallback(Callback):
+    def __init__(self, connector: "_SignalConnector") -> None:
+        self.connector = connector
+
+        # Register the same method for all callback methods
+        on_methods = [f for f in dir(Callback) if f.startswith("on_") and callable(getattr(Callback, f))]
+        for f in ["setup", "teardown"] + on_methods:
+            setattr(self, f, self.notify_connector)
+
+    def notify_connector(self, *args: Any, **kwargs: Any) -> None:
+        self.connector._process_signals()
+
+
 class _SignalConnector:
+    """Listen for process signals to, for example, requeue a job on a SLURM cluster.
+
+    The connector only stores the reception of signals in flags and then processes them at the next possible opportunity
+    in the current loop. This minimizes the amount of code running in signal handlers, because file IO in signal
+    handlers can crash the process. This also guarantees that we are not checkpointing in the middle of a backward pass
+    or similar.
+
+    """
+
     def __init__(self, trainer: "pl.Trainer") -> None:
-        self.received_sigterm = False
         self.trainer = trainer
+
+        # This flag is checked by the trainer and the loops to exit gracefully
+        self.received_sigterm = False
+
+        self.sigterm_flag = _SignalFlag()
+        self.requeue_flag = _SignalFlag()
+
         self._original_handlers: dict[_SIGNUM, _HANDLER] = {}
 
+    def register_callback(self) -> None:
+        callback = _SignalHandlerCallback(self)
+        self.trainer.callbacks = self.trainer.callbacks + [callback]
+
     def register_signal_handlers(self) -> None:
-        self.received_sigterm = False
-        self._original_handlers = self._get_current_signal_handlers()
+        if _IS_WINDOWS:
+            # Windows seems to have signal incompatibilities
+            rank_zero_info("Not registering signal handlers on Windows OS")
+            return
 
-        sigusr_handlers: list[_HANDLER] = []
-        sigterm_handlers: list[_HANDLER] = [self._sigterm_notifier_fn]
+        if threading.current_thread() is not threading.main_thread():
+            # Skip signal registration to allow training in non-main threads
+            rank_zero_debug("Not registering signal handlers outside of the main thread")
+            return
+
+        self._register_signal_handler(signal.SIGTERM, self.sigterm_flag)
 
         environment = self.trainer._accelerator_connector.cluster_environment
         if isinstance(environment, SLURMEnvironment) and environment.auto_requeue:
-            log.info("SLURM auto-requeueing enabled. Setting signal handlers.")
-            sigusr_handlers.append(self._slurm_sigusr_handler_fn)
-            sigterm_handlers.append(self._sigterm_handler_fn)
-
-        # Windows seems to have signal incompatibilities
-        if not _IS_WINDOWS:
-            sigusr = environment.requeue_signal if isinstance(environment, SLURMEnvironment) else signal.SIGUSR1
-            assert sigusr is not None
-            if sigusr_handlers and not self._has_already_handler(sigusr):
-                self._register_signal(sigusr, _HandlersCompose(sigusr_handlers))
-
-            # we have our own handler, but include existing ones too
-            if self._has_already_handler(signal.SIGTERM):
-                sigterm_handlers.append(signal.getsignal(signal.SIGTERM))
-            self._register_signal(signal.SIGTERM, _HandlersCompose(sigterm_handlers))
-
-    def _slurm_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
-        rank_zero_info(f"Handling auto-requeue signal: {signum}")
-
-        # save logger to make sure we get all the metrics
+            requeue_signal = environment.requeue_signal
+            if requeue_signal is None:
+                rank_zero_warn("Requested SLURM auto-requeueing, but signal is disabled. Could not set it up.")
+            else:
+                rank_zero_info(f"SLURM auto-requeueing enabled. Setting signal handlers for {requeue_signal.name}.")
+                self._register_signal_handler(requeue_signal, self.requeue_flag)
+
+    def _process_signals(self) -> None:
+        if self.requeue_flag.check_and_reset():
+            rank_zero_info("Handling auto-requeue signal")
+            self._slurm_requeue()
+
+        if self.sigterm_flag.check_and_reset():
+            log.info(rank_prefixed_message("Received SIGTERM. Stopping.", self.trainer.local_rank))
+            # Forward signal to subprocesses the first time it is received
+            if not self.received_sigterm:
+                launcher = self.trainer.strategy.launcher
+                if launcher is not None:
+                    launcher.kill(signal.SIGTERM)
+            self.received_sigterm = True
+
+    def _slurm_requeue(self) -> None:
+        # Save logger to make sure we get all the metrics
         for logger in self.trainer.loggers:
             logger.finalize("finished")
 
@@ -78,84 +137,46 @@ def _slurm_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
         self.trainer.save_checkpoint(hpc_save_path)
 
         if self.trainer.is_global_zero:
-            # find job id
-            array_job_id = os.getenv("SLURM_ARRAY_JOB_ID")
-            if array_job_id is not None:
-                array_task_id = os.environ["SLURM_ARRAY_TASK_ID"]
-                job_id = f"{array_job_id}_{array_task_id}"
-            else:
-                job_id = os.environ["SLURM_JOB_ID"]
-
-            assert re.match("[0-9_-]+", job_id)
+            job_id = self._slurm_job_id()
 
             cmd = ["scontrol", "requeue", job_id]
-            # requeue job
-            log.info(f"requeing job {job_id}...")
+
+            # Requeue job
+            log.info(f"Requeueing job {job_id}...")
             try:
-                result = call(cmd)
+                result = subprocess.run(cmd, capture_output=True, text=True)
             except FileNotFoundError:
-                # This can occur if a subprocess call to `scontrol` is run outside a shell context
-                # Re-attempt call (now with shell context). If any error is raised, propagate to user.
-                # When running a shell command, it should be passed as a single string.
-                result = call(" ".join(cmd), shell=True)
-
-            # print result text
-            if result == 0:
-                log.info(f"Requeued SLURM job: {job_id}")
-            else:
-                log.warning(f"Requeuing SLURM job {job_id} failed with error code {result}")
-
-    def _sigterm_notifier_fn(self, signum: _SIGNUM, _: FrameType) -> None:
-        log.info(rank_prefixed_message(f"Received SIGTERM: {signum}", self.trainer.local_rank))
-        if not self.received_sigterm:
-            launcher = self.trainer.strategy.launcher
-            if launcher is not None:
-                launcher.kill(signum)
+                # This can occur if a subprocess call to `scontrol` is run outside a shell context.
+                # Try enlisting the help of the shell to resolve the `scontrol` binary.
+                result = subprocess.run(" ".join(cmd), capture_output=True, text=True, shell=True)
 
-            # New broadcast logic
-            if dist.is_available() and dist.is_initialized() and self.trainer.world_size > 1:
-                sigterm_tensor = torch.tensor([1], device=self.trainer.strategy.root_device)
-                dist.broadcast(sigterm_tensor, src=0)
-
-        self.received_sigterm = True
+            # Print result text
+            if result.returncode == 0:
+                log.info(f"Requeued SLURM job {job_id}")
+            else:
+                log.warning(
+                    f"Requeueing SLURM job {job_id} failed with error code {result.returncode}: {result.stderr}"
+                )
 
-    def _sigterm_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
-        log.info(f"Bypassing SIGTERM: {signum}")
+        self.trainer.should_stop = True
 
     def teardown(self) -> None:
-        """Restores the signals that were previously configured before :class:`_SignalConnector` replaced them."""
-        for signum, handler in self._original_handlers.items():
-            if handler is not None:
-                self._register_signal(signum, handler)
-        self._original_handlers = {}
-
-    @staticmethod
-    def _get_current_signal_handlers() -> dict[_SIGNUM, _HANDLER]:
-        """Collects the currently assigned signal handlers."""
-        valid_signals = _SignalConnector._valid_signals()
-        if not _IS_WINDOWS:
-            # SIGKILL and SIGSTOP are not allowed to be modified by the user
-            valid_signals -= {signal.SIGKILL, signal.SIGSTOP}
-        return {signum: signal.getsignal(signum) for signum in valid_signals}
-
-    @staticmethod
-    def _valid_signals() -> set[signal.Signals]:
-        """Returns all valid signals supported on the current platform."""
-        return signal.valid_signals()
-
-    @staticmethod
-    def _has_already_handler(signum: _SIGNUM) -> bool:
-        return signal.getsignal(signum) not in (None, signal.SIG_DFL)
-
-    @staticmethod
-    def _register_signal(signum: _SIGNUM, handlers: _HANDLER) -> None:
+        """Restores the signals that :class:`_SignalConnector` overwrote."""
         if threading.current_thread() is threading.main_thread():
-            signal.signal(signum, handlers)  # type: ignore[arg-type]
+            for signum, handler in self._original_handlers.items():
+                signal.signal(signum, handler)
+            self._original_handlers = {}
 
-    def __getstate__(self) -> dict:
-        state = self.__dict__.copy()
-        state["_original_handlers"] = {}
-        return state
+    def _register_signal_handler(self, signum: _SIGNUM, handler: _HANDLER) -> None:
+        orig_handler = signal.getsignal(signum)
+        self._original_handlers[signum] = orig_handler
+        signal.signal(signum, _HandlersCompose([orig_handler, handler]))
+
+    def _slurm_job_id(self) -> str:
+        array_job_id = os.environ.get("SLURM_ARRAY_JOB_ID")
+        if array_job_id is not None:
+            array_task_id = os.environ["SLURM_ARRAY_TASK_ID"]
+            return f"{array_job_id}_{array_task_id}"
+        return os.environ["SLURM_JOB_ID"]
 
 
 def _get_sigkill_signal() -> _SIGNUM:
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 6f947160ba9cb..949080ea27607 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -453,6 +453,9 @@ def __init__(
             max_time,
         )
 
+        # Handle signals as part of the train loop
+        self._signal_connector.register_callback()
+
         # init data flags
         self.check_val_every_n_epoch: Optional[int]
         self._data_connector.on_trainer_init(
diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py
index 878298c6bfd94..befd8120a232d 100644
--- a/tests/tests_pytorch/conftest.py
+++ b/tests/tests_pytorch/conftest.py
@@ -32,7 +32,6 @@
 from lightning.fabric.utilities.distributed import _destroy_dist_connection, _distributed_is_initialized
 from lightning.fabric.utilities.imports import _IS_WINDOWS
 from lightning.pytorch.accelerators import XLAAccelerator
-from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector
 
 from tests_pytorch import _PATH_DATASETS
 
@@ -113,7 +112,7 @@ def restore_signal_handlers():
 
     This is a safety net for tests that don't run Trainer's teardown.
     """
-    valid_signals = _SignalConnector._valid_signals()
+    valid_signals = signal.valid_signals()
     if not _IS_WINDOWS:
         # SIGKILL and SIGSTOP are not allowed to be modified by the user
         valid_signals -= {signal.SIGKILL, signal.SIGSTOP}
diff --git a/tests/tests_pytorch/trainer/connectors/test_signal_connector.py b/tests/tests_pytorch/trainer/connectors/test_signal_connector.py
index 83c5c2bb7e02b..0f3b39d7d5762 100644
--- a/tests/tests_pytorch/trainer/connectors/test_signal_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_signal_connector.py
@@ -23,7 +23,7 @@
 from lightning.fabric.utilities.imports import _IS_WINDOWS
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector
+from lightning.pytorch.trainer.connectors.signal_connector import _HandlersCompose, _SignalConnector, _SignalFlag
 from lightning.pytorch.utilities.exceptions import SIGTERMException
 from tests_pytorch.helpers.runif import RunIf
 
@@ -73,112 +73,85 @@ def training_step(self, batch, batch_idx):
 @RunIf(skip_windows=True)
 @pytest.mark.parametrize("auto_requeue", [True, False])
 @pytest.mark.parametrize("requeue_signal", [signal.SIGUSR1, signal.SIGUSR2, signal.SIGHUP] if not _IS_WINDOWS else [])
-def test_auto_requeue_custom_signal_flag(auto_requeue, requeue_signal):
+def test_auto_requeue_signal_handlers(auto_requeue, requeue_signal):
     trainer = Trainer(plugins=[SLURMEnvironment(auto_requeue=auto_requeue, requeue_signal=requeue_signal)])
     connector = _SignalConnector(trainer)
     connector.register_signal_handlers()
 
-    if auto_requeue:
-        sigterm_handlers = signal.getsignal(signal.SIGTERM).signal_handlers
-        assert len(sigterm_handlers) == 2
-        assert sigterm_handlers[1].__qualname__ == "_SignalConnector._sigterm_handler_fn"
+    sigterm_handler = signal.getsignal(signal.SIGTERM)
+    assert isinstance(sigterm_handler, _HandlersCompose)
+    assert len(sigterm_handler.signal_handlers) == 2
+    assert sigterm_handler.signal_handlers[0] is signal.SIG_DFL
+    assert isinstance(sigterm_handler.signal_handlers[1], _SignalFlag)
 
-        sigusr_handlers = signal.getsignal(requeue_signal).signal_handlers
-        assert len(sigusr_handlers) == 1
-        assert sigusr_handlers[0].__qualname__ == "_SignalConnector._slurm_sigusr_handler_fn"
+    if auto_requeue:
+        sigusr_handler = signal.getsignal(requeue_signal)
+        assert isinstance(sigusr_handler, _HandlersCompose)
+        assert len(sigusr_handler.signal_handlers) == 2
+        assert sigusr_handler.signal_handlers[0] is signal.SIG_DFL
+        assert isinstance(sigusr_handler.signal_handlers[1], _SignalFlag)
     else:
-        sigterm_handlers = signal.getsignal(signal.SIGTERM).signal_handlers
-        assert len(sigterm_handlers) == 1
-        assert sigterm_handlers[0].__qualname__ == "_SignalConnector._sigterm_notifier_fn"
-
         assert signal.getsignal(requeue_signal) is signal.SIG_DFL
 
     connector.teardown()
 
 
 @RunIf(skip_windows=True)
-@mock.patch("lightning.pytorch.trainer.connectors.signal_connector.call")
-@mock.patch("lightning.pytorch.trainer.Trainer.save_checkpoint", mock.MagicMock())
+@mock.patch("subprocess.run", return_value=Mock(returncode=0))
+@mock.patch("lightning.pytorch.trainer.Trainer.save_checkpoint")
 @mock.patch.dict(os.environ, {"SLURM_JOB_ID": "12345"})
-def test_auto_requeue_job(call_mock):
-    call_mock.return_value = 0
+def test_auto_requeue_job(ckpt_mock, run_mock):
     trainer = Trainer(plugins=[SLURMEnvironment()])
     connector = _SignalConnector(trainer)
-    connector._slurm_sigusr_handler_fn(None, None)
-    call_mock.assert_called_once_with(["scontrol", "requeue", "12345"])
+    connector.requeue_flag.set()
+    connector._process_signals()
+
+    ckpt_mock.assert_called_once()
+    run_mock.assert_called_once()
+    assert run_mock.call_args[0][0] == ["scontrol", "requeue", "12345"]
 
 
 @RunIf(skip_windows=True)
-@mock.patch("lightning.pytorch.trainer.connectors.signal_connector.call")
-@mock.patch("lightning.pytorch.trainer.Trainer.save_checkpoint", mock.MagicMock())
+@mock.patch("subprocess.run", return_value=Mock(returncode=0))
+@mock.patch("lightning.pytorch.trainer.Trainer.save_checkpoint")
 @mock.patch.dict(os.environ, {"SLURM_JOB_ID": "12346", "SLURM_ARRAY_JOB_ID": "12345", "SLURM_ARRAY_TASK_ID": "2"})
-def test_auto_requeue_array_job(call_mock):
-    call_mock.return_value = 0
+def test_auto_requeue_array_job(ckpt_mock, run_mock):
     trainer = Trainer(plugins=[SLURMEnvironment()])
     connector = _SignalConnector(trainer)
-    connector._slurm_sigusr_handler_fn(None, None)
-    call_mock.assert_called_once_with(["scontrol", "requeue", "12345_2"])
+    connector.requeue_flag.set()
+    connector._process_signals()
 
-
-@RunIf(skip_windows=True)
-@mock.patch("lightning.pytorch.trainer.connectors.signal_connector.call")
-@mock.patch("lightning.pytorch.trainer.Trainer.save_checkpoint", mock.MagicMock())
-@mock.patch.dict(os.environ, {"SLURM_JOB_ID": "invalid"})
-def test_auto_requeue_invalid_job_id(call_mock):
-    call_mock.return_value = 0
-    trainer = Trainer(plugins=[SLURMEnvironment()])
-    connector = _SignalConnector(trainer)
-    with pytest.raises(AssertionError):
-        connector._slurm_sigusr_handler_fn(None, None)
+    ckpt_mock.assert_called_once()
+    run_mock.assert_called_once()
+    assert run_mock.call_args[0][0] == ["scontrol", "requeue", "12345_2"]
 
 
 def _registering_signals():
     trainer = Trainer()
     trainer._signal_connector.register_signal_handlers()
+    trainer._signal_connector.teardown()
 
 
 @RunIf(skip_windows=True)
-def test_signal_connector_in_thread():
+def test_no_signal_handling_in_non_main_thread():
     with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
         for future in concurrent.futures.as_completed([executor.submit(_registering_signals)]):
             assert future.exception() is None
 
 
-def signal_handler():
-    pass
-
-
-class SignalHandlers:
-    def signal_handler(self):
-        pass
-
-
-@pytest.mark.parametrize(
-    ("handler", "expected_return"),
-    [
-        (None, False),
-        (signal.Handlers.SIG_IGN, True),
-        (signal.Handlers.SIG_DFL, False),
-        (signal_handler, True),
-        (SignalHandlers().signal_handler, True),
-    ],
-)
-def test_has_already_handler(handler, expected_return):
-    """Test that the SignalConnector detects whether a signal handler is already attached."""
-    with mock.patch("lightning.pytorch.trainer.connectors.signal_connector.signal.getsignal", return_value=handler):
-        assert _SignalConnector._has_already_handler(signal.SIGTERM) is expected_return
-
-
-def test_sigterm_notifier_fn():
+def test_sigterm_sets_flag_and_kills_subprocesses():
     trainer = Mock()
     launcher = Mock()
     trainer.strategy.launcher = launcher
     connector = _SignalConnector(trainer)
 
     assert not connector.received_sigterm
-    connector._sigterm_notifier_fn(signal.SIGTERM, Mock())
-    launcher.kill.assert_called_once_with(15)
+    connector.sigterm_flag.set()
+    connector._process_signals()
+    launcher.kill.assert_called_once_with(signal.SIGTERM)
     assert connector.received_sigterm
+    launcher.reset_mock()
 
-    connector._sigterm_notifier_fn(signal.SIGTERM, Mock())
+    connector.sigterm_flag.set()
+    connector._process_signals()
     launcher.kill.assert_not_called()
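
Reviewer note: the core idea of this change is that the signal handlers themselves do almost nothing — they only flip a flag — and the flags are processed later from a callback that runs between hook invocations of the training loop, so no file IO or checkpointing happens inside a handler. Below is a minimal, self-contained sketch of that pattern in plain Python, independent of Lightning; the Flag class and the toy loop are illustrative only and not part of this diff.

import signal
import time


class Flag:
    """Set from the signal handler; checked and cleared by the main loop."""

    def __init__(self) -> None:
        self.state = False

    def __call__(self, signum, frame) -> None:
        # Do nothing but flip a flag: no file IO or logging inside the handler
        self.state = True

    def check_and_reset(self) -> bool:
        state, self.state = self.state, False
        return state


sigterm_flag = Flag()
signal.signal(signal.SIGTERM, sigterm_flag)

for step in range(1_000_000):
    time.sleep(0.01)  # stand-in for one training step
    # Process the signal at a safe point, outside the handler
    if sigterm_flag.check_and_reset():
        print("SIGTERM received: save a checkpoint here, then stop")
        break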