Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

-

### Fixed

- Sanitize profiler filenames when saving to avoid crashes due to invalid characters ([#21395](https://github.com/Lightning-AI/pytorch-lightning/pull/21395))


## [2.6.0] - 2025-11-28

Expand Down
12 changes: 11 additions & 1 deletion src/lightning/pytorch/profilers/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import logging
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Generator
from contextlib import contextmanager
Expand Down Expand Up @@ -81,6 +82,7 @@ def _prepare_filename(
action_name: Optional[str] = None,
extension: str = ".txt",
split_token: str = "-", # noqa: S107
sanitize: bool = True,
) -> str:
args = []
if self._stage is not None:
Expand All @@ -91,7 +93,15 @@ def _prepare_filename(
args.append(str(self._local_rank))
if action_name is not None:
args.append(action_name)
return split_token.join(args) + extension
base = split_token.join(args)
if sanitize:
# Replace a set of path-unsafe characters across platforms with '_'
base = re.sub(r"[\\/:*?\"<>|\n\r\t]", "_", base)
base = re.sub(r"_+", "_", base)
base = base.strip()
if not base:
base = "profile"
return base + extension

def _prepare_streams(self) -> None:
if self._write_stream is not None:
Expand Down
27 changes: 27 additions & 0 deletions tests/tests_pytorch/profilers/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,33 @@ def test_advanced_profiler_dump_states(tmp_path):
assert len(data) > 0


@pytest.mark.parametrize("char", ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\n", "\r", "\t"])
def test_advanced_profiler_dump_states_sanitizes_filename(tmp_path, char):
"""Profiler should sanitize action names to produce filesystem-safe .prof filenames.

This guards against errors when callbacks or actions include path-unsafe characters (e.g., metric names with '/').

"""
profiler = AdvancedProfiler(dirpath=tmp_path, dump_stats=True)
action_name = f"before{char}after"
with profiler.profile(action_name):
pass

profiler.describe()

prof_files = [f for f in os.listdir(tmp_path) if f.endswith(".prof")]
assert len(prof_files) == 1
prof_name = prof_files[0]

# Ensure none of the path-unsafe characters are present in the produced filename
forbidden = ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\n", "\r", "\t"]
for bad in forbidden:
assert bad not in prof_name

# File should be non-empty
assert (tmp_path / prof_name).read_bytes()


def test_advanced_profiler_value_errors(advanced_profiler):
"""Ensure errors are raised where expected."""
action = "test"
Expand Down
Loading