diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 3558fe3cadc66..c7645c5ce5bcb 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -20,6 +20,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - +### Fixed + +- Sanitize profiler filenames when saving to avoid crashes due to invalid characters ([#21395](https://github.com/Lightning-AI/pytorch-lightning/pull/21395)) + ## [2.6.0] - 2025-11-28 diff --git a/src/lightning/pytorch/profilers/profiler.py b/src/lightning/pytorch/profilers/profiler.py index e8a8c60881062..ed26d9df791c9 100644 --- a/src/lightning/pytorch/profilers/profiler.py +++ b/src/lightning/pytorch/profilers/profiler.py @@ -15,6 +15,7 @@ import logging import os +import re from abc import ABC, abstractmethod from collections.abc import Generator from contextlib import contextmanager @@ -81,6 +82,7 @@ def _prepare_filename( action_name: Optional[str] = None, extension: str = ".txt", split_token: str = "-", # noqa: S107 + sanitize: bool = True, ) -> str: args = [] if self._stage is not None: @@ -91,7 +93,15 @@ def _prepare_filename( args.append(str(self._local_rank)) if action_name is not None: args.append(action_name) - return split_token.join(args) + extension + base = split_token.join(args) + if sanitize: + # Replace a set of path-unsafe characters across platforms with '_' + base = re.sub(r"[\\/:*?\"<>|\n\r\t]", "_", base) + base = re.sub(r"_+", "_", base) + base = base.strip() + if not base: + base = "profile" + return base + extension def _prepare_streams(self) -> None: if self._write_stream is not None: diff --git a/tests/tests_pytorch/profilers/test_profiler.py b/tests/tests_pytorch/profilers/test_profiler.py index 5423c8ab0e89d..d48250e63b986 100644 --- a/tests/tests_pytorch/profilers/test_profiler.py +++ b/tests/tests_pytorch/profilers/test_profiler.py @@ -322,6 +322,33 @@ def test_advanced_profiler_dump_states(tmp_path): assert len(data) > 0 +@pytest.mark.parametrize("char", ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\n", "\r", "\t"]) +def test_advanced_profiler_dump_states_sanitizes_filename(tmp_path, char): + """Profiler should sanitize action names to produce filesystem-safe .prof filenames. + + This guards against errors when callbacks or actions include path-unsafe characters (e.g., metric names with '/'). + + """ + profiler = AdvancedProfiler(dirpath=tmp_path, dump_stats=True) + action_name = f"before{char}after" + with profiler.profile(action_name): + pass + + profiler.describe() + + prof_files = [f for f in os.listdir(tmp_path) if f.endswith(".prof")] + assert len(prof_files) == 1 + prof_name = prof_files[0] + + # Ensure none of the path-unsafe characters are present in the produced filename + forbidden = ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\n", "\r", "\t"] + for bad in forbidden: + assert bad not in prof_name + + # File should be non-empty + assert (tmp_path / prof_name).read_bytes() + + def test_advanced_profiler_value_errors(advanced_profiler): """Ensure errors are raised where expected.""" action = "test"