Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/pytorch/_graveyard/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,4 @@ def _patch_classes() -> None:
_patch_sys_modules()
_patch_classes()

SingleTPUStrategy.register_strategies(pl.strategies.StrategyRegistry) # type: ignore[has-type]
SingleTPUStrategy.register_strategies(pl.strategies.StrategyRegistry)
229 changes: 125 additions & 104 deletions src/lightning/pytorch/trainer/connectors/signal_connector.py
Original file line number Diff line number Diff line change
@@ -1,161 +1,182 @@
import logging
import os
import re
import signal
import subprocess
import threading
from subprocess import call
from types import FrameType
from typing import Any, Callable, Union

import torch
import torch.distributed as dist

import lightning.pytorch as pl
from lightning.fabric.plugins.environments import SLURMEnvironment
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_info
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_debug, rank_zero_info, rank_zero_warn

# copied from signal.pyi
_SIGNUM = Union[int, signal.Signals]
_HANDLER = Union[Callable[[_SIGNUM, FrameType], Any], int, signal.Handlers, None]
_HANDLER = Union[Callable[[_SIGNUM, FrameType | None], Any], int, signal.Handlers, None]

log = logging.getLogger(__name__)


class _HandlersCompose:
def __init__(self, signal_handlers: Union[list[_HANDLER], _HANDLER]) -> None:
if not isinstance(signal_handlers, list):
signal_handlers = [signal_handlers]
def __init__(self, signal_handlers: list[_HANDLER]) -> None:
self.signal_handlers = signal_handlers

def __call__(self, signum: _SIGNUM, frame: FrameType) -> None:
def __call__(self, signum: _SIGNUM, frame: FrameType | None) -> None:
for signal_handler in self.signal_handlers:
if signal_handler is signal.SIG_DFL or signal_handler is signal.SIG_IGN:
# If the handler is ignore, we skip it. Since there is no way for us to
# trigger, the default signal handler, we ignore that one too
continue
if isinstance(signal_handler, int):
signal_handler = signal.getsignal(signal_handler)
if callable(signal_handler):
signal_handler(signum, frame)


class _SignalFlag:
"""Becomes true when called as a signal handler."""

def __init__(self) -> None:
self.state = False

def __call__(self, signum: _SIGNUM, _: FrameType | None) -> None:
self.set()

def set(self) -> None:
self.state = True

def check_and_reset(self) -> bool:
"""Check the flag and reset it to false."""
state = self.state
self.state = False
return state


class _SignalHandlerCallback(Callback):
def __init__(self, connector: "_SignalConnector") -> None:
self.connector = connector

# Register same method for all callback methods
on_methods = [f for f in dir(Callback) if f.startswith("on_") and callable(getattr(Callback, f))]
for f in ["setup", "teardown"] + on_methods:
setattr(self, f, self.notify_connector)

def notify_connector(self, *args: Any, **kwargs: Any) -> None:
self.connector._process_signals()


class _SignalConnector:
"""Listen for process signals to, for example, requeue a job on a SLURM cluster.

The connector only stores the reception of signals in flags and then processes them at the next possible opportunity
in the current loop. This minimizes the amount of code running in signal handlers, because file IO in signal
handlers can crash the process. This also guarantees that we are not checkpointing in the middle of a backward pass
or similar.

"""

def __init__(self, trainer: "pl.Trainer") -> None:
self.received_sigterm = False
self.trainer = trainer

# This flag is checked by the trainer and loops to exit gracefully
self.received_sigterm = False

self.sigterm_flag = _SignalFlag()
self.requeue_flag = _SignalFlag()

self._original_handlers: dict[_SIGNUM, _HANDLER] = {}

def register_callback(self) -> None:
callback = _SignalHandlerCallback(self)
self.trainer.callbacks = self.trainer.callbacks + [callback]

def register_signal_handlers(self) -> None:
self.received_sigterm = False
self._original_handlers = self._get_current_signal_handlers()
if _IS_WINDOWS:
# Windows seems to have signal incompatibilities
rank_zero_info("Not registering signal handlers on Windows OS")
return

sigusr_handlers: list[_HANDLER] = []
sigterm_handlers: list[_HANDLER] = [self._sigterm_notifier_fn]
if threading.current_thread() is not threading.main_thread():
# Skip signal registration to allow training in non-main-threads
rank_zero_debug("Not registering signal handlers outside of the main thread")
return

self._register_signal_handler(signal.SIGTERM, self.sigterm_flag)

environment = self.trainer._accelerator_connector.cluster_environment
if isinstance(environment, SLURMEnvironment) and environment.auto_requeue:
log.info("SLURM auto-requeueing enabled. Setting signal handlers.")
sigusr_handlers.append(self._slurm_sigusr_handler_fn)
sigterm_handlers.append(self._sigterm_handler_fn)

# Windows seems to have signal incompatibilities
if not _IS_WINDOWS:
sigusr = environment.requeue_signal if isinstance(environment, SLURMEnvironment) else signal.SIGUSR1
assert sigusr is not None
if sigusr_handlers and not self._has_already_handler(sigusr):
self._register_signal(sigusr, _HandlersCompose(sigusr_handlers))

# we have our own handler, but include existing ones too
if self._has_already_handler(signal.SIGTERM):
sigterm_handlers.append(signal.getsignal(signal.SIGTERM))
self._register_signal(signal.SIGTERM, _HandlersCompose(sigterm_handlers))

def _slurm_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
rank_zero_info(f"Handling auto-requeue signal: {signum}")

# save logger to make sure we get all the metrics
requeue_signal = environment.requeue_signal
if requeue_signal is None:
rank_zero_warn("Requested SLURM auto-requeueing, but signal is disabled. Could not set it up.")
else:
rank_zero_info(f"SLURM auto-requeueing enabled. Setting signal handlers for {requeue_signal.name}.")
self._register_signal_handler(requeue_signal, self.requeue_flag)

def _process_signals(self) -> None:
if self.requeue_flag.check_and_reset():
rank_zero_info("Handling auto-requeue signal")
self._slurm_requeue()

if self.sigterm_flag.check_and_reset():
log.info(rank_prefixed_message("Received SIGTERM. Stopping.", self.trainer.local_rank))
# Forward signal to subprocesses the first time it is received
if not self.received_sigterm:
launcher = self.trainer.strategy.launcher
if launcher is not None:
launcher.kill(signal.SIGTERM)
self.received_sigterm = True

def _slurm_requeue(self) -> None:
# Save logger to make sure we get all the metrics
for logger in self.trainer.loggers:
logger.finalize("finished")

hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(self.trainer.default_root_dir)
self.trainer.save_checkpoint(hpc_save_path)

if self.trainer.is_global_zero:
# find job id
array_job_id = os.getenv("SLURM_ARRAY_JOB_ID")
if array_job_id is not None:
array_task_id = os.environ["SLURM_ARRAY_TASK_ID"]
job_id = f"{array_job_id}_{array_task_id}"
else:
job_id = os.environ["SLURM_JOB_ID"]

assert re.match("[0-9_-]+", job_id)
job_id = self._slurm_job_id()
cmd = ["scontrol", "requeue", job_id]

# requeue job
log.info(f"requeing job {job_id}...")
# Requeue job
log.info(f"Requeueing job {job_id}...")
try:
result = call(cmd)
result = subprocess.run(cmd, capture_output=True, text=True)
except FileNotFoundError:
# This can occur if a subprocess call to `scontrol` is run outside a shell context
# Re-attempt call (now with shell context). If any error is raised, propagate to user.
# When running a shell command, it should be passed as a single string.
result = call(" ".join(cmd), shell=True)

# print result text
if result == 0:
log.info(f"Requeued SLURM job: {job_id}")
else:
log.warning(f"Requeuing SLURM job {job_id} failed with error code {result}")

def _sigterm_notifier_fn(self, signum: _SIGNUM, _: FrameType) -> None:
log.info(rank_prefixed_message(f"Received SIGTERM: {signum}", self.trainer.local_rank))
if not self.received_sigterm:
launcher = self.trainer.strategy.launcher
if launcher is not None:
launcher.kill(signum)
# This can occur if a subprocess call to `scontrol` is run outside a shell context.
# Try enlisting the help of the shell to resolve the `scontrol` binary.
result = subprocess.run(" ".join(cmd), capture_output=True, text=True, shell=True)

# New broadcast logic
if dist.is_available() and dist.is_initialized() and self.trainer.world_size > 1:
sigterm_tensor = torch.tensor([1], device=self.trainer.strategy.root_device)
dist.broadcast(sigterm_tensor, src=0)

self.received_sigterm = True
# Print result text
if result.returncode == 0:
log.info(f"Requeued SLURM job {job_id}")
else:
log.warning(
f"Requeueing SLURM job {job_id} failed with error code {result.returncode}: {result.stderr}"
)

def _sigterm_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
log.info(f"Bypassing SIGTERM: {signum}")
self.trainer.should_stop = True

def teardown(self) -> None:
"""Restores the signals that were previously configured before :class:`_SignalConnector` replaced them."""
for signum, handler in self._original_handlers.items():
if handler is not None:
self._register_signal(signum, handler)
self._original_handlers = {}

@staticmethod
def _get_current_signal_handlers() -> dict[_SIGNUM, _HANDLER]:
"""Collects the currently assigned signal handlers."""
valid_signals = _SignalConnector._valid_signals()
if not _IS_WINDOWS:
# SIGKILL and SIGSTOP are not allowed to be modified by the user
valid_signals -= {signal.SIGKILL, signal.SIGSTOP}
return {signum: signal.getsignal(signum) for signum in valid_signals}

@staticmethod
def _valid_signals() -> set[signal.Signals]:
"""Returns all valid signals supported on the current platform."""
return signal.valid_signals()

@staticmethod
def _has_already_handler(signum: _SIGNUM) -> bool:
return signal.getsignal(signum) not in (None, signal.SIG_DFL)

@staticmethod
def _register_signal(signum: _SIGNUM, handlers: _HANDLER) -> None:
"""Restores the signals that :class:`_SignalConnector` overwrote."""
if threading.current_thread() is threading.main_thread():
signal.signal(signum, handlers) # type: ignore[arg-type]
for signum, handler in self._original_handlers.items():
signal.signal(signum, handler)
self._original_handlers = {}

def __getstate__(self) -> dict:
state = self.__dict__.copy()
state["_original_handlers"] = {}
return state
def _register_signal_handler(self, signum: _SIGNUM, handler: _HANDLER) -> None:
orig_handler = signal.getsignal(signum)
self._original_handlers[signum] = orig_handler
signal.signal(signum, _HandlersCompose([orig_handler, handler]))

def _slurm_job_id(self) -> str:
array_job_id = os.environ.get("SLURM_ARRAY_JOB_ID")
if array_job_id is not None:
array_task_id = os.environ["SLURM_ARRAY_TASK_ID"]
return f"{array_job_id}_{array_task_id}"
return os.environ["SLURM_JOB_ID"]


def _get_sigkill_signal() -> _SIGNUM:
Expand Down
3 changes: 3 additions & 0 deletions src/lightning/pytorch/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,9 @@ def __init__(
max_time,
)

# Handle signals as part of the train loop
self._signal_connector.register_callback()

# init data flags
self.check_val_every_n_epoch: Optional[int]
self._data_connector.on_trainer_init(
Expand Down
3 changes: 1 addition & 2 deletions tests/tests_pytorch/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from lightning.fabric.utilities.distributed import _destroy_dist_connection, _distributed_is_initialized
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.pytorch.accelerators import XLAAccelerator
from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector
from tests_pytorch import _PATH_DATASETS


Expand Down Expand Up @@ -113,7 +112,7 @@ def restore_signal_handlers():
This is a safety net for tests that don't run Trainer's teardown.

"""
valid_signals = _SignalConnector._valid_signals()
valid_signals = signal.valid_signals()
if not _IS_WINDOWS:
# SIGKILL and SIGSTOP are not allowed to be modified by the user
valid_signals -= {signal.SIGKILL, signal.SIGSTOP}
Expand Down
Loading
Loading