diff --git a/src/lightning/fabric/strategies/launchers/multiprocessing.py b/src/lightning/fabric/strategies/launchers/multiprocessing.py
index 3b3e180e63f41..13fe1a86c0898 100644
--- a/src/lightning/fabric/strategies/launchers/multiprocessing.py
+++ b/src/lightning/fabric/strategies/launchers/multiprocessing.py
@@ -195,17 +195,29 @@ def _check_bad_cuda_fork() -> None:
     Lightning users.
 
     """
-    if not torch.cuda.is_initialized():
-        return
-
-    message = (
-        "Lightning can't create new processes if CUDA is already initialized. Did you manually call"
-        " `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any"
-        " other way? Please remove any such calls, or change the selected strategy."
-    )
-    if _IS_INTERACTIVE:
-        message += " You will have to restart the Python kernel."
-    raise RuntimeError(message)
+    # Use PyTorch's internal check for bad fork state, which is more accurate than just checking whether CUDA
+    # is initialized. This allows passive CUDA initialization (e.g., from library imports or device queries)
+    # while still catching the actual problematic case where a CUDA context was created before forking.
+    _is_in_bad_fork = getattr(torch.cuda, "_is_in_bad_fork", None)
+    if _is_in_bad_fork is not None and _is_in_bad_fork():
+        message = (
+            "Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, "
+            "you must use the 'spawn' start method or avoid CUDA initialization in the main process."
+        )
+        if _IS_INTERACTIVE:
+            message += " You will have to restart the Python kernel."
+        raise RuntimeError(message)
+
+    # Fall back to the old check if _is_in_bad_fork is not available (older PyTorch versions)
+    if _is_in_bad_fork is None and torch.cuda.is_initialized():
+        message = (
+            "Lightning can't create new processes if CUDA is already initialized. Did you manually call"
+            " `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any"
+            " other way? Please remove any such calls, or change the selected strategy."
+        )
+        if _IS_INTERACTIVE:
+            message += " You will have to restart the Python kernel."
+        raise RuntimeError(message)
 
 
 def _disable_module_memory_sharing(data: Any) -> Any:
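For context (not part of the diff), here is a minimal sketch of the behavior the new guard relies on. It assumes a CUDA-capable machine and a POSIX platform where the "fork" start method is available; `report` is a hypothetical helper, and `torch.cuda._is_in_bad_fork` is a private PyTorch API that only reports True inside a child process forked after the parent initialized CUDA.

# bad_fork_demo.py -- illustrative sketch only
import multiprocessing as mp

import torch


def report(tag: str) -> None:
    # Query PyTorch's internal bad-fork flag; fall back gracefully on older versions.
    bad_fork = getattr(torch.cuda, "_is_in_bad_fork", None)
    print(tag, "bad fork:", bad_fork() if bad_fork is not None else "n/a")


if __name__ == "__main__":
    torch.cuda.init()  # initialize CUDA in the parent, as a user script might do implicitly
    report("parent")   # expected: False -- plain CUDA init alone no longer trips the guard

    ctx = mp.get_context("fork")
    child = ctx.Process(target=report, args=("forked child",))  # expected: True -> guard raises
    child.start()
    child.join()

The parent reports False even though CUDA is initialized, while the forked child reports True; the rewritten `_check_bad_cuda_fork` raises only in the latter case, which is what lets passive CUDA initialization pass while still blocking unsafe forks.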