|
1 | 1 | import logging |
2 | 2 | import os |
3 | | -import re |
4 | 3 | import signal |
| 4 | +import subprocess |
5 | 5 | import threading |
6 | | -from subprocess import call |
7 | 6 | from types import FrameType |
8 | 7 | from typing import Any, Callable, Union |
9 | 8 |
|
10 | | -import torch |
11 | | -import torch.distributed as dist |
12 | | - |
13 | 9 | import lightning.pytorch as pl |
14 | 10 | from lightning.fabric.plugins.environments import SLURMEnvironment |
15 | 11 | from lightning.fabric.utilities.imports import _IS_WINDOWS |
16 | | -from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_info |
| 12 | +from lightning.pytorch.callbacks import Callback |
| 13 | +from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_debug, rank_zero_info, rank_zero_warn |
17 | 14 |
|
# copied from signal.pyi
# A signal identifier: a plain int or a member of the `signal.Signals` enum.
_SIGNUM = Union[int, signal.Signals]
# Anything `signal.getsignal` can return / `signal.signal` can accept: a callable handler,
# SIG_DFL/SIG_IGN (ints / `signal.Handlers`), or None for handlers not installed from Python.
_HANDLER = Union[Callable[[_SIGNUM, FrameType | None], Any], int, signal.Handlers, None]

log = logging.getLogger(__name__)
23 | 20 |
|
24 | 21 |
|
class _HandlersCompose:
    """Invoke a sequence of signal handlers one after another for a single signal."""

    def __init__(self, signal_handlers: list[_HANDLER]) -> None:
        self.signal_handlers = signal_handlers

    def __call__(self, signum: _SIGNUM, frame: FrameType | None) -> None:
        for handler in self.signal_handlers:
            if handler is signal.SIG_DFL or handler is signal.SIG_IGN:
                # SIG_IGN means "ignore", so we skip it. There is no way for us to invoke
                # the default handler from Python, so SIG_DFL is skipped as well.
                continue
            # Any other int is treated as a signal number whose current handler is looked up.
            if isinstance(handler, int):
                handler = signal.getsignal(handler)
            if callable(handler):
                handler(signum, frame)
37 | 36 |
|
38 | 37 |
|
class _SignalFlag:
    """Becomes true when called as a signal handler."""

    def __init__(self) -> None:
        self.state = False

    def __call__(self, signum: _SIGNUM, _: FrameType | None) -> None:
        # Signal handlers should do as little as possible; flipping a flag is all we do here.
        self.set()

    def set(self) -> None:
        self.state = True

    def check_and_reset(self) -> bool:
        """Return the current flag value and reset it to false."""
        previous, self.state = self.state, False
        return previous
| 55 | + |
| 56 | + |
class _SignalHandlerCallback(Callback):
    """A callback that pokes the signal connector on every hook so pending signals get processed."""

    def __init__(self, connector: "_SignalConnector") -> None:
        self.connector = connector

        # Point every callback hook (setup/teardown plus all ``on_*`` methods) at the
        # same notifier by shadowing the class methods with instance attributes.
        hook_names = ["setup", "teardown"]
        hook_names += [name for name in dir(Callback) if name.startswith("on_") and callable(getattr(Callback, name))]
        for name in hook_names:
            setattr(self, name, self.notify_connector)

    def notify_connector(self, *args: Any, **kwargs: Any) -> None:
        """Let the connector act on any signals received since the last hook."""
        self.connector._process_signals()
| 68 | + |
| 69 | + |
class _SignalConnector:
    """Listen for process signals to, for example, requeue a job on a SLURM cluster.

    The connector only stores the reception of signals in flags and then processes them at the next possible opportunity
    in the current loop. This minimizes the amount of code running in signal handlers, because file IO in signal
    handlers can crash the process. This also guarantees that we are not checkpointing in the middle of a backward pass
    or similar.

    """

    def __init__(self, trainer: "pl.Trainer") -> None:
        self.trainer = trainer

        # This flag is checked by the trainer and loops to exit gracefully
        self.received_sigterm = False

        # Set asynchronously by the registered signal handlers; polled in ``_process_signals``
        self.sigterm_flag = _SignalFlag()
        self.requeue_flag = _SignalFlag()

        # Handlers that were installed before we replaced them; restored in ``teardown``
        self._original_handlers: dict[_SIGNUM, _HANDLER] = {}

    def register_callback(self) -> None:
        """Append a callback to the trainer that polls the signal flags on every hook."""
        callback = _SignalHandlerCallback(self)
        self.trainer.callbacks = self.trainer.callbacks + [callback]

    def register_signal_handlers(self) -> None:
        """Install the SIGTERM handler and, on SLURM with auto-requeue enabled, the requeue-signal handler."""
        if _IS_WINDOWS:
            # Windows seems to have signal incompatibilities
            rank_zero_info("Not registering signal handlers on Windows OS")
            return

        if threading.current_thread() is not threading.main_thread():
            # Skip signal registration to allow training in non-main-threads
            rank_zero_debug("Not registering signal handlers outside of the main thread")
            return

        self._register_signal_handler(signal.SIGTERM, self.sigterm_flag)

        environment = self.trainer._accelerator_connector.cluster_environment
        if isinstance(environment, SLURMEnvironment) and environment.auto_requeue:
            requeue_signal = environment.requeue_signal
            if requeue_signal is None:
                rank_zero_warn("Requested SLURM auto-requeueing, but signal is disabled. Could not set it up.")
            else:
                rank_zero_info(f"SLURM auto-requeueing enabled. Setting signal handlers for {requeue_signal.name}.")
                self._register_signal_handler(requeue_signal, self.requeue_flag)

    def _process_signals(self) -> None:
        """Act on signals received since the last call; runs inside the loop, not inside a signal handler."""
        if self.requeue_flag.check_and_reset():
            rank_zero_info("Handling auto-requeue signal")
            self._slurm_requeue()

        if self.sigterm_flag.check_and_reset():
            log.info(rank_prefixed_message("Received SIGTERM. Stopping.", self.trainer.local_rank))
            # Forward signal to subprocesses the first time it is received
            if not self.received_sigterm:
                launcher = self.trainer.strategy.launcher
                if launcher is not None:
                    launcher.kill(signal.SIGTERM)
            self.received_sigterm = True

    def _slurm_requeue(self) -> None:
        """Save an HPC checkpoint and requeue the current SLURM job via ``scontrol``."""
        # Save logger to make sure we get all the metrics
        for logger in self.trainer.loggers:
            logger.finalize("finished")

        hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(self.trainer.default_root_dir)
        self.trainer.save_checkpoint(hpc_save_path)

        if self.trainer.is_global_zero:
            job_id = self._slurm_job_id()
            cmd = ["scontrol", "requeue", job_id]

            # Requeue job
            log.info(f"Requeueing job {job_id}...")
            try:
                result = subprocess.run(cmd, capture_output=True, text=True)
            except FileNotFoundError:
                # This can occur if a subprocess call to `scontrol` is run outside a shell context.
                # Try enlisting the help of the shell to resolve the `scontrol` binary.
                result = subprocess.run(" ".join(cmd), capture_output=True, text=True, shell=True)

            # Print result text
            if result.returncode == 0:
                log.info(f"Requeued SLURM job {job_id}")
            else:
                log.warning(
                    f"Requeueing SLURM job {job_id} failed with error code {result.returncode}: {result.stderr}"
                )

        self.trainer.should_stop = True

    def teardown(self) -> None:
        """Restores the signals that :class:`_SignalConnector` overwrote."""
        if threading.current_thread() is threading.main_thread():
            for signum, handler in self._original_handlers.items():
                # `signal.getsignal` returns None for handlers not installed from Python, and
                # `signal.signal` rejects None, so those entries must be skipped on restore.
                if handler is not None:
                    signal.signal(signum, handler)
            self._original_handlers = {}

    def _register_signal_handler(self, signum: _SIGNUM, handler: _HANDLER) -> None:
        """Install ``handler`` for ``signum`` while keeping the previously installed handler active."""
        orig_handler = signal.getsignal(signum)
        self._original_handlers[signum] = orig_handler
        signal.signal(signum, _HandlersCompose([orig_handler, handler]))

    def _slurm_job_id(self) -> str:
        """Return the SLURM job id, using ``<array_job_id>_<task_id>`` for job arrays."""
        array_job_id = os.environ.get("SLURM_ARRAY_JOB_ID")
        if array_job_id is not None:
            array_task_id = os.environ["SLURM_ARRAY_TASK_ID"]
            return f"{array_job_id}_{array_task_id}"
        return os.environ["SLURM_JOB_ID"]
159 | 180 |
|
160 | 181 |
|
161 | 182 | def _get_sigkill_signal() -> _SIGNUM: |
|
0 commit comments