From 1c20b3815067bdb328f378fb617e6ec0773f9b3e Mon Sep 17 00:00:00 2001
From: arrdel
Date: Wed, 3 Dec 2025 15:30:42 -0500
Subject: [PATCH 1/2] Fix ddp_notebook CUDA fork check to allow passive
 initialization

The previous implementation used torch.cuda.is_initialized(), which
returns True even when CUDA is passively initialized (e.g., during
library imports or device availability checks). This caused false
positives in environments like Kaggle notebooks, where libraries may
query CUDA without creating a context.

This fix uses PyTorch's internal torch.cuda._is_in_bad_fork() function,
which more accurately detects when the process is in an actual bad fork
state (i.e., CUDA was initialized with a context and the process was
then forked). The change allows passive CUDA initialization while still
catching genuinely problematic cases.

For older PyTorch versions that don't have _is_in_bad_fork, the code
falls back to the old check.

Fixes #21389
---
 .../strategies/launchers/multiprocessing.py   | 34 +++++++++++++------
 1 file changed, 23 insertions(+), 11 deletions(-)

diff --git a/src/lightning/fabric/strategies/launchers/multiprocessing.py b/src/lightning/fabric/strategies/launchers/multiprocessing.py
index 3b3e180e63f41..7d3bb1ad22544 100644
--- a/src/lightning/fabric/strategies/launchers/multiprocessing.py
+++ b/src/lightning/fabric/strategies/launchers/multiprocessing.py
@@ -195,17 +195,29 @@ def _check_bad_cuda_fork() -> None:
     Lightning users.
 
     """
-    if not torch.cuda.is_initialized():
-        return
-
-    message = (
-        "Lightning can't create new processes if CUDA is already initialized. Did you manually call"
-        " `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any"
-        " other way? Please remove any such calls, or change the selected strategy."
-    )
-    if _IS_INTERACTIVE:
-        message += " You will have to restart the Python kernel."
-    raise RuntimeError(message)
+    # Use PyTorch's internal check for bad fork state, which is more accurate than just checking if CUDA
+    # is initialized. This allows passive CUDA initialization (e.g., from library imports or device queries)
+    # while still catching actual problematic cases where CUDA context was created before forking.
+    _is_in_bad_fork = getattr(torch.cuda, "_is_in_bad_fork", None)
+    if _is_in_bad_fork is not None and _is_in_bad_fork():
+        message = (
+            "Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, "
+            "you must use the 'spawn' start method or avoid CUDA initialization in the main process."
+        )
+        if _IS_INTERACTIVE:
+            message += " You will have to restart the Python kernel."
+        raise RuntimeError(message)
+    
+    # Fallback to the old check if _is_in_bad_fork is not available (older PyTorch versions)
+    if _is_in_bad_fork is None and torch.cuda.is_initialized():
+        message = (
+            "Lightning can't create new processes if CUDA is already initialized. Did you manually call"
+            " `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any"
+            " other way? Please remove any such calls, or change the selected strategy."
+        )
+        if _IS_INTERACTIVE:
+            message += " You will have to restart the Python kernel."
+        raise RuntimeError(message)
 
 
 def _disable_module_memory_sharing(data: Any) -> Any:

From fc8a8ec716751ffbd0e8ec205fc593a504e1f3d4 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 3 Dec 2025 20:47:43 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/fabric/strategies/launchers/multiprocessing.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/fabric/strategies/launchers/multiprocessing.py b/src/lightning/fabric/strategies/launchers/multiprocessing.py
index 7d3bb1ad22544..13fe1a86c0898 100644
--- a/src/lightning/fabric/strategies/launchers/multiprocessing.py
+++ b/src/lightning/fabric/strategies/launchers/multiprocessing.py
@@ -207,7 +207,7 @@ def _check_bad_cuda_fork() -> None:
         if _IS_INTERACTIVE:
             message += " You will have to restart the Python kernel."
         raise RuntimeError(message)
-    
+
     # Fallback to the old check if _is_in_bad_fork is not available (older PyTorch versions)
     if _is_in_bad_fork is None and torch.cuda.is_initialized():
         message = (
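
A minimal, self-contained sketch of the patched check, for anyone who wants
to try the logic outside of Lightning. It assumes only public PyTorch plus
the private torch.cuda._is_in_bad_fork helper the patch relies on (an
internal API that may change between releases); the name check_bad_cuda_fork
and the interactive parameter are illustrative stand-ins for Lightning's
_check_bad_cuda_fork and _IS_INTERACTIVE, not part of any real API.

    import torch


    def check_bad_cuda_fork(interactive: bool = False) -> None:
        # Prefer PyTorch's fork-state tracking when available. _is_in_bad_fork()
        # returns True only in a process that was forked after a CUDA context
        # already existed, so passive initialization (library imports, device
        # queries) does not trip it.
        is_in_bad_fork = getattr(torch.cuda, "_is_in_bad_fork", None)
        if is_in_bad_fork is not None:
            if is_in_bad_fork():
                message = (
                    "Cannot re-initialize CUDA in forked subprocess. Use the 'spawn' "
                    "start method or avoid CUDA initialization in the main process."
                )
                if interactive:
                    message += " You will have to restart the Python kernel."
                raise RuntimeError(message)
            return  # CUDA may be passively initialized; that is acceptable here

        # Older PyTorch without _is_in_bad_fork: fall back to the strict check
        # that treats any prior CUDA initialization as unsafe.
        if torch.cuda.is_initialized():
            raise RuntimeError(
                "Lightning can't create new processes if CUDA is already initialized."
            )


    if __name__ == "__main__":
        check_bad_cuda_fork()
        print("no bad fork detected; safe to launch worker processes")

Note the early return when _is_in_bad_fork exists and reports False: the
fallback is_initialized() check is never consulted on modern PyTorch, which
is exactly what lets a passively initialized CUDA runtime (the Kaggle
notebook scenario from the commit message) proceed.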