From edc21c73d194b976a578271e46b593478c1466ec Mon Sep 17 00:00:00 2001
From: Xuehai Pan
Date: Sat, 29 Nov 2025 18:38:55 +0800
Subject: [PATCH 01/59] chore(pre-commit): update pre-commit hooks

---
 .pre-commit-config.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cedc15ae..fd33e89f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -38,7 +38,7 @@ repos:
     hooks:
       - id: cpplint
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.14.6
+    rev: v0.14.7
     hooks:
       - id: ruff-check
        args: [--fix, --exit-non-zero-on-fix]
@@ -50,7 +50,7 @@ repos:
       - id: codespell
        additional_dependencies: [".[toml]"]
  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.18.2
+    rev: v1.19.0
    hooks:
      - id: mypy
       exclude: |

From d5a1cc40280500af0140904ed75e36016f07a816 Mon Sep 17 00:00:00 2001
From: Xuehai Pan
Date: Wed, 8 Oct 2025 23:56:29 +0800
Subject: [PATCH 02/59] test: reorganize test files

---
 tests/concurrent/test_subinterpreters.py | 14 ++++++++++++++
 .../test_threading.py}                    |  8 ++++----
 2 files changed, 18 insertions(+), 4 deletions(-)
 create mode 100644 tests/concurrent/test_subinterpreters.py
 rename tests/{test_concurrent.py => concurrent/test_threading.py} (97%)

diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py
new file mode 100644
index 00000000..59af7b0d
--- /dev/null
+++ b/tests/concurrent/test_subinterpreters.py
@@ -0,0 +1,14 @@
+# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== diff --git a/tests/test_concurrent.py b/tests/concurrent/test_threading.py similarity index 97% rename from tests/test_concurrent.py rename to tests/concurrent/test_threading.py index 1a74a5af..f6de0fae 100644 --- a/tests/test_concurrent.py +++ b/tests/concurrent/test_threading.py @@ -55,8 +55,8 @@ atexit.register(EXECUTOR.shutdown) -def concurrent_run(func): - futures = [EXECUTOR.submit(func) for _ in range(NUM_FUTURES)] +def concurrent_run(func, /, *args, **kwargs): + futures = [EXECUTOR.submit(func, *args, **kwargs) for _ in range(NUM_FUTURES)] future2index = {future: i for i, future in enumerate(futures)} completed_futures = sorted(as_completed(futures), key=future2index.get) first_exception = next(filter(None, (future.exception() for future in completed_futures)), None) @@ -92,7 +92,7 @@ def test_fn(): for result in concurrent_run(test_fn): assert result == expected - for result in concurrent_run(lambda: optree.tree_unflatten(treespec, leaves)): + for result in concurrent_run(optree.tree_unflatten, treespec, leaves): assert result == tree @@ -353,7 +353,7 @@ def test_tree_iter_thread_safe( namespace=namespace, ) - results = concurrent_run(lambda: list(it)) + results = concurrent_run(list, it) for seq in results: assert sorted(seq) == seq assert sorted(itertools.chain.from_iterable(results)) == list(range(num_leaves)) From 8bd5fde8b6255c9b02532e4f0ab1d68848811ce2 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 9 Oct 2025 00:46:35 +0800 Subject: [PATCH 03/59] test: add subinterpreters tests --- tests/concurrent/test_subinterpreters.py | 39 ++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index 59af7b0d..df3ff8eb 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -12,3 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== + +import atexit +import sys + +import pytest + +from helpers import ( + PYPY, + WASM, + Py_DEBUG, + Py_GIL_DISABLED, +) + + +if PYPY or WASM or sys.version_info < (3, 14): + pytest.skip('Test for CPython 3.14+ only', allow_module_level=True) + + +from concurrent.futures import InterpreterPoolExecutor + + +if Py_GIL_DISABLED and not Py_DEBUG: + NUM_WORKERS = 32 + NUM_FUTURES = 128 +else: + NUM_WORKERS = 4 + NUM_FUTURES = 16 + + +EXECUTOR = InterpreterPoolExecutor(max_workers=NUM_WORKERS) +atexit.register(EXECUTOR.shutdown) + + +def run(func, /, *args, **kwargs): + future = EXECUTOR.submit(func, *args, **kwargs) + exception = future.exception() + if exception is not None: + raise exception + return future.result() From 23abd08f5f395cb1a3d8b9e467df21af86af68c8 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 29 Nov 2025 19:11:38 +0800 Subject: [PATCH 04/59] chore: reorder dependencies --- .github/workflows/build.yml | 4 +++- docs/requirements.txt | 4 ++-- pyproject.toml | 13 +++++++------ src/optree.cpp | 18 +++++++----------- src/treespec/serialization.cpp | 2 +- tests/concurrent/test_subinterpreters.py | 15 ++++++++++++++- tests/concurrent/test_threading.py | 10 +++++++++- tests/requirements.txt | 17 ++++++++++------- 8 files changed, 53 insertions(+), 30 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index bb85665d..89605871 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -42,9 +42,11 @@ concurrency: cancel-in-progress: ${{ github.event_name == 'pull_request' }} env: + OPTREE_CXX_WERROR: "OFF" + _GLIBCXX_USE_CXX11_ABI: "1" + PYTHONUNBUFFERED: "1" PYTHON_TAG: "py3" # to be updated PYTHON_VERSION: "3" # to be updated - _GLIBCXX_USE_CXX11_ABI: "1" COLUMNS: "100" FORCE_COLOR: "1" CLICOLOR_FORCE: "1" diff --git a/docs/requirements.txt b/docs/requirements.txt index 55e1da0a..85a0ae49 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -3,14 +3,14 @@ sphinx sphinx-autoapi sphinx-autobuild +sphinx-autodoc-typehints sphinx-copybutton sphinx-rtd-theme sphinxcontrib-bibtex -sphinx-autodoc-typehints docutils --extra-index-url https://download.pytorch.org/whl/cpu -jax[cpu] >= 0.4.6 +jax[cpu] numpy torch diff --git a/pyproject.toml b/pyproject.toml index e2b91265..61fd419f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,19 +58,20 @@ jax = ["jax"] numpy = ["numpy"] torch = ["torch"] lint = [ - "ruff", - "pylint[spelling]", - "mypy", + "cpplint", "doc8", + "mypy", + "pre-commit", "pyenchant", + "pylint[spelling]", + "ruff", "xdoctest", - "cpplint", - "pre-commit", ] test = [ "pytest", "pytest-cov", "covdefaults", + "xdoctest", "rich", # Pin to minimum for compatibility test "typing-extensions == 4.6.0; python_version < '3.13' and platform_system == 'Linux'", @@ -84,10 +85,10 @@ docs = [ "sphinx", "sphinx-autoapi", "sphinx-autobuild", + "sphinx-autodoc-typehints", "sphinx-copybutton", "sphinx-rtd-theme", "sphinxcontrib-bibtex", - "sphinx-autodoc-typehints", "docutils", "jax[cpu]", "numpy", diff --git a/src/optree.cpp b/src/optree.cpp index dfa27e43..1cfbc9bc 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -111,11 +111,9 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] return f'0x{self:08X}' BUILDTIME_METADATA.update( - **{ - name: HexInt(value) - for name, value in BUILDTIME_METADATA.items() - if name.endswith('_HEX') and isinstance(value, int) - }, + (name, HexInt(value)) + for name, value in 
BUILDTIME_METADATA.items() + if name.endswith('_HEX') and isinstance(value, int) ) BUILDTIME_METADATA = types.MappingProxyType(BUILDTIME_METADATA) @@ -277,10 +275,8 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] #endif auto * const PyTreeKind_Type = reinterpret_cast(PyTreeKindTypeObject.ptr()); PyTreeKind_Type->tp_name = "optree.PyTreeKind"; - py::setattr(PyTreeKindTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree)); - py::setattr(PyTreeKindTypeObject.ptr(), - "NUM_KINDS", - py::int_(py::ssize_t(PyTreeKind::NumKinds))); + py::setattr(PyTreeKindTypeObject, Py_Get_ID(__module__), Py_Get_ID(optree)); + py::setattr(PyTreeKindTypeObject, "NUM_KINDS", py::int_(py::ssize_t(PyTreeKind::NumKinds))); auto PyTreeSpecTypeObject = #if defined(PYBIND11_HAS_INTERNALS_WITH_SMART_HOLDER_SUPPORT) @@ -303,7 +299,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] py::module_local()); auto * const PyTreeSpec_Type = reinterpret_cast(PyTreeSpecTypeObject.ptr()); PyTreeSpec_Type->tp_name = "optree.PyTreeSpec"; - py::setattr(PyTreeSpecTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree)); + py::setattr(PyTreeSpecTypeObject, Py_Get_ID(__module__), Py_Get_ID(optree)); PyTreeSpecTypeObject .def("unflatten", @@ -501,7 +497,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] py::module_local()); auto * const PyTreeIter_Type = reinterpret_cast(PyTreeIterTypeObject.ptr()); PyTreeIter_Type->tp_name = "optree.PyTreeIter"; - py::setattr(PyTreeIterTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree)); + py::setattr(PyTreeIterTypeObject, Py_Get_ID(__module__), Py_Get_ID(optree)); PyTreeIterTypeObject .def(py::init, bool, std::string>(), diff --git a/src/treespec/serialization.cpp b/src/treespec/serialization.cpp index 99af5beb..b80770c9 100644 --- a/src/treespec/serialization.cpp +++ b/src/treespec/serialization.cpp @@ -297,7 +297,7 @@ py::object PyTreeSpec::ToPickleable() const { ssize_t i = 0; for (const auto &node : m_traversal) { const scoped_critical_section2 cs{ - node.custom != nullptr ? py::handle{node.custom->type.ptr()} : py::handle{}, + node.custom != nullptr ? 
py::handle{node.custom->type} : py::handle{}, node.node_data}; TupleSetItem(node_states, i++, diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index df3ff8eb..82cef11e 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -30,7 +30,7 @@ pytest.skip('Test for CPython 3.14+ only', allow_module_level=True) -from concurrent.futures import InterpreterPoolExecutor +from concurrent.futures import InterpreterPoolExecutor, as_completed if Py_GIL_DISABLED and not Py_DEBUG: @@ -51,3 +51,16 @@ def run(func, /, *args, **kwargs): if exception is not None: raise exception return future.result() + + +def concurrent_run(func, /, *args, **kwargs): + futures = [EXECUTOR.submit(func, *args, **kwargs) for _ in range(NUM_FUTURES)] + future2index = {future: i for i, future in enumerate(futures)} + completed_futures = sorted(as_completed(futures), key=future2index.get) + first_exception = next(filter(None, (future.exception() for future in completed_futures)), None) + if first_exception is not None: + raise first_exception + return [future.result() for future in completed_futures] + + +run(object) # warm-up diff --git a/tests/concurrent/test_threading.py b/tests/concurrent/test_threading.py index f6de0fae..2d03897f 100644 --- a/tests/concurrent/test_threading.py +++ b/tests/concurrent/test_threading.py @@ -55,6 +55,14 @@ atexit.register(EXECUTOR.shutdown) +def run(func, /, *args, **kwargs): + future = EXECUTOR.submit(func, *args, **kwargs) + exception = future.exception() + if exception is not None: + raise exception + return future.result() + + def concurrent_run(func, /, *args, **kwargs): futures = [EXECUTOR.submit(func, *args, **kwargs) for _ in range(NUM_FUTURES)] future2index = {future: i for i, future in enumerate(futures)} @@ -65,7 +73,7 @@ def concurrent_run(func, /, *args, **kwargs): return [future.result() for future in completed_futures] -concurrent_run(object) # warm-up +run(object) # warm-up @parametrize( diff --git a/tests/requirements.txt b/tests/requirements.txt index 4ea69f96..58bc3c84 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -3,16 +3,19 @@ pytest pytest-cov covdefaults -rich -ruff -pylint[spelling] -mypy -doc8 -pyenchant xdoctest +rich + cpplint +doc8 +mypy pre-commit +pyenchant +pylint[spelling] +ruff + +--extra-index-url https://download.pytorch.org/whl/cpu -jax[cpu] >= 0.4.6 +jax[cpu] numpy torch From 65a8fce10f69cc90d688efaa30795a61beb15a8a Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 29 Nov 2025 19:49:26 +0800 Subject: [PATCH 05/59] test: add import tests --- tests/concurrent/test_subinterpreters.py | 26 +++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index 82cef11e..26969491 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -14,6 +14,9 @@ # ============================================================================== import atexit +import contextlib +import importlib +import re import sys import pytest @@ -26,10 +29,16 @@ ) -if PYPY or WASM or sys.version_info < (3, 14): +if ( + PYPY + or WASM + or sys.version_info < (3, 14) + or not getattr(sys.implementation, 'supports_isolated_interpreters', False) +): pytest.skip('Test for CPython 3.14+ only', allow_module_level=True) +from concurrent import interpreters from concurrent.futures import InterpreterPoolExecutor, as_completed @@ -64,3 
+73,18 @@ def concurrent_run(func, /, *args, **kwargs): run(object) # warm-up + + +def test_import_failure(): + pattern = re.escape('module optree._C does not support loading in subinterpreters') + with pytest.raises(ImportError, match=pattern) as excinfo: + run(importlib.import_module, 'optree') + + with contextlib.closing(interpreters.create()) as subinterpreter: + with pytest.raises(interpreters.ExecutionFailed, match=pattern) as excinfo: + subinterpreter.call(importlib.import_module, 'optree') + assert excinfo.value.excinfo.type.__name__ == 'ImportError' + + with pytest.raises(interpreters.ExecutionFailed, match=pattern) as excinfo: + subinterpreter.exec('import optree') + assert excinfo.value.excinfo.type.__name__ == 'ImportError' From 261729a219df37aa6fe76367cef3b6ee1ba63159 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 29 Nov 2025 21:30:39 +0800 Subject: [PATCH 06/59] chore(workflows): update find options --- .github/workflows/tests-with-pydebug.yml | 42 +++++++++----- .github/workflows/tests.yml | 71 +++++++++++++++++------- 2 files changed, 79 insertions(+), 34 deletions(-) diff --git a/.github/workflows/tests-with-pydebug.yml b/.github/workflows/tests-with-pydebug.yml index ee48a90f..722fca3a 100644 --- a/.github/workflows/tests-with-pydebug.yml +++ b/.github/workflows/tests-with-pydebug.yml @@ -348,10 +348,15 @@ jobs: "--junit-xml=junit-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml" ) make test PYTESTOPTS="${PYTESTOPTS[*]}" - if [[ -n "$(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp")" ]]; then + + CORE_DUMP_FILES="$( + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + )" + if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "::error::Coredump files found, indicating a crash during tests." >&2 echo "Coredump files:" >&2 - ls -alh $(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp") >&2 + ls -alh ${CORE_DUMP_FILES} >&2 exit 1 fi @@ -369,9 +374,13 @@ jobs: if: ${{ !cancelled() }} shell: bash run: | - if [[ -n "$(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp")" ]]; then + CORE_DUMP_FILES="$( + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + )" + if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "Found core dumps:" - ls -alh $(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp") + ls -alh ${CORE_DUMP_FILES} BACKTRACE_COMMAND="" if [[ "${{ runner.os }}" == 'Linux' ]]; then echo "::group::Install GDB" @@ -392,18 +401,23 @@ jobs: fi if [[ -n "${BACKTRACE_COMMAND}" ]]; then echo "Collecting backtraces:" - find . -iname "core.*.[1-9]*" -exec bash -xc " - echo '::group::backtrace from: {}'; - ${BACKTRACE_COMMAND}; - echo '::endgroup::'; - " ';' - find . -iname "core_*.dmp" -exec bash -xc " - echo '::group::backtrace from: {}'; - ${BACKTRACE_COMMAND}; - echo '::endgroup::'; - " ';' + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" \ + -exec bash -xc " + echo '::group::backtrace from: {}'; + ${BACKTRACE_COMMAND}; + echo '::endgroup::'; + " ';' + find . -type d -path "./venv" -prune \ + -o -iname "core_*.dmp" \ + -exec bash -xc " + echo '::group::backtrace from: {}'; + ${BACKTRACE_COMMAND}; + echo '::endgroup::'; + " ';' fi echo "::warning::Coredump files found, see backtraces above for details." 
>&2 + exit 1 fi - name: Upload coverage to Codecov diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 87cee51a..7cd1dda6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -209,14 +209,19 @@ jobs: popd rm -rf venv ) + if [[ "$?" -ne 0 ]]; then echo "::error::Failed to install with C++17." >&2 exit 1 fi - if [[ -n "$(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp")" ]]; then + CORE_DUMP_FILES="$( + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + )" + if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "::error::Coredump files found, indicating a crash during tests." >&2 echo "Coredump files:" >&2 - ls -alh $(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp") >&2 + ls -alh ${CORE_DUMP_FILES} >&2 exit 1 fi @@ -251,6 +256,7 @@ jobs: "${{ env.PYTHON }}" ) fi + ( set -ex ${{ env.PYTHON }} -m venv venv @@ -262,14 +268,19 @@ jobs: popd rm -rf venv ) + if [[ "$?" -ne 0 ]]; then echo "::error::Failed to install with CMake from PyPI." >&2 exit 1 fi - if [[ -n "$(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp")" ]]; then + CORE_DUMP_FILES="$( + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + )" + if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "::error::Coredump files found, indicating a crash during tests." >&2 echo "Coredump files:" >&2 - ls -alh $(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp") >&2 + ls -alh ${CORE_DUMP_FILES} >&2 exit 1 fi @@ -290,6 +301,7 @@ jobs: "${{ env.PYTHON }}" ) fi + ( set -ex ${{ env.PYTHON }} -m venv venv @@ -300,14 +312,19 @@ jobs: popd rm -rf venv ) + if [[ "$?" -ne 0 ]]; then echo "::error::Failed to install with CMake from PyPI (no system CMake)." >&2 exit 1 fi - if [[ -n "$(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp")" ]]; then + CORE_DUMP_FILES="$( + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + )" + if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "::error::Coredump files found, indicating a crash during tests." >&2 echo "Coredump files:" >&2 - ls -alh $(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp") >&2 + ls -alh ${CORE_DUMP_FILES} >&2 exit 1 fi @@ -323,10 +340,15 @@ jobs: "--junit-xml=junit-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml" ) make test PYTESTOPTS="${PYTESTOPTS[*]}" - if [[ -n "$(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp")" ]]; then + + CORE_DUMP_FILES="$( + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + )" + if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "::error::Coredump files found, indicating a crash during tests." >&2 echo "Coredump files:" >&2 - ls -alh $(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp") >&2 + ls -alh ${CORE_DUMP_FILES} >&2 exit 1 fi @@ -344,9 +366,13 @@ jobs: if: ${{ !cancelled() }} shell: bash run: | - if [[ -n "$(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp")" ]]; then + CORE_DUMP_FILES="$( + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + )" + if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "Found core dumps:" - ls -alh $(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp") + ls -alh ${CORE_DUMP_FILES} BACKTRACE_COMMAND="" if [[ "${{ runner.os }}" == 'Linux' ]]; then echo "::group::Install GDB" @@ -367,18 +393,23 @@ jobs: fi if [[ -n "${BACKTRACE_COMMAND}" ]]; then echo "Collecting backtraces:" - find . 
-iname "core.*.[1-9]*" -exec bash -xc " - echo '::group::backtrace from: {}'; - ${BACKTRACE_COMMAND}; - echo '::endgroup::'; - " ';' - find . -iname "core_*.dmp" -exec bash -xc " - echo '::group::backtrace from: {}'; - ${BACKTRACE_COMMAND}; - echo '::endgroup::'; - " ';' + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" \ + -exec bash -xc " + echo '::group::backtrace from: {}'; + ${BACKTRACE_COMMAND}; + echo '::endgroup::'; + " ';' + find . -type d -path "./venv" -prune \ + -o -iname "core_*.dmp" \ + -exec bash -xc " + echo '::group::backtrace from: {}'; + ${BACKTRACE_COMMAND}; + echo '::endgroup::'; + " ';' fi echo "::warning::Coredump files found, see backtraces above for details." >&2 + exit 1 fi - name: Upload coverage to Codecov From 6b9b38a6f4cfa141084117ffd9a1dea333abf2c8 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sun, 30 Nov 2025 16:58:56 +0800 Subject: [PATCH 07/59] feat: set `py::multiple_interpreters::shared_gil()` --- include/optree/registry.h | 5 +++ optree/_C.pyi | 3 ++ src/optree.cpp | 37 +++++++++++++++------ src/registry.cpp | 19 +++++++++++ tests/concurrent/test_subinterpreters.py | 6 ++++ tests/helpers.py | 24 ++++++++++---- tests/test_registry.py | 41 ++++++++++++++++++++++++ tests/test_typing.py | 4 +-- 8 files changed, 122 insertions(+), 17 deletions(-) diff --git a/include/optree/registry.h b/include/optree/registry.h index d39c792e..27c31bea 100644 --- a/include/optree/registry.h +++ b/include/optree/registry.h @@ -19,6 +19,7 @@ limitations under the License. #include // std::uint8_t #include // std::shared_ptr +#include // std::optional, std::nullopt #include // std::string #include // std::unordered_map #include // std::unordered_set @@ -98,6 +99,10 @@ class PyTreeTypeRegistry { using RegistrationPtr = std::shared_ptr; + // Gets the number of registered types. + [[nodiscard]] ssize_t Size( + const std::optional ®istry_namespace = std::nullopt) const; + // Registers a new custom type. Objects of `cls` will be treated as container node types in // PyTrees. static void Register(const py::object &cls, diff --git a/optree/_C.pyi b/optree/_C.pyi index ee5ced47..4c92f799 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -213,3 +213,6 @@ def set_dict_insertion_ordered( /, namespace: str = '', ) -> None: ... +def registry_size( + namespace: str | None = None, +) -> int: ... diff --git a/src/optree.cpp b/src/optree.cpp index 1cfbc9bc..49582070 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -31,6 +31,9 @@ limitations under the License. 
# include #endif +// NOLINTNEXTLINE[bugprone-macro-parentheses] +#define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0)) + namespace optree { py::module_ GetCxxModule(const std::optional &module) { @@ -52,9 +55,6 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] std::string(__FILE_RELPATH_FROM_PROJECT_ROOT__) + ")"; mod.attr("Py_TPFLAGS_BASETYPE") = py::int_(Py_TPFLAGS_BASETYPE); - // NOLINTNEXTLINE[bugprone-macro-parentheses] -#define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0)) - // Meta information during build py::dict BUILDTIME_METADATA{}; BUILDTIME_METADATA["PY_VERSION"] = py::str(PY_VERSION); @@ -99,8 +99,6 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] BUILDTIME_METADATA["GLIBCXX_USE_CXX11_ABI"] = py::bool_(false); #endif -#undef NONZERO_OR_EMPTY - mod.attr("BUILDTIME_METADATA") = std::move(BUILDTIME_METADATA); py::exec( R"py( @@ -143,6 +141,19 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] py::arg("cls"), py::pos_only(), py::arg("namespace") = "") + .def( + "registry_size", + [](const std::optional ®istry_namespace) { + const ssize_t count = + PyTreeTypeRegistry::Singleton()->Size(registry_namespace); + EXPECT_EQ(count, + PyTreeTypeRegistry::Singleton()->Size(registry_namespace) + 1, + "The number of registered types in the two registries should match " + "up to the extra None type in the NoneIsLeaf=false registry."); + return count; + }, + "Get the number of registered types.", + py::arg("namespace") = std::nullopt) .def("is_dict_insertion_ordered", &PyTreeSpec::IsDictInsertionOrdered, "Return whether need to preserve the dict insertion order during flattening.", @@ -528,10 +539,18 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] } // namespace optree +// NOLINTBEGIN[cppcoreguidelines-pro-bounds-pointer-arithmetic,cppcoreguidelines-pro-type-vararg] #if PYBIND11_VERSION_HEX >= 0x020D00F0 // pybind11 2.13.0 -// NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic,cppcoreguidelines-pro-type-vararg] -PYBIND11_MODULE(_C, mod, py::mod_gil_not_used()) { optree::BuildModule(mod); } +# if defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \ + NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) +PYBIND11_MODULE(_C, mod, py::mod_gil_not_used(), py::multiple_interpreters::shared_gil()) +# else +PYBIND11_MODULE(_C, mod, py::mod_gil_not_used()) +# endif #else -// NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic,cppcoreguidelines-pro-type-vararg] -PYBIND11_MODULE(_C, mod) { optree::BuildModule(mod); } +PYBIND11_MODULE(_C, mod) #endif +{ + optree::BuildModule(mod); +} +// NOLINTEND[cppcoreguidelines-pro-bounds-pointer-arithmetic,cppcoreguidelines-pro-type-vararg] diff --git a/src/registry.cpp b/src/registry.cpp index 8f991434..03fa8ba3 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -16,6 +16,7 @@ limitations under the License. 
*/ #include // std::make_shared +#include // std::optional #include // std::ostringstream #include // std::string #include // std::remove_const_t @@ -65,6 +66,24 @@ template template PyTreeTypeRegistry *PyTreeTypeRegistry::Singleton(); template PyTreeTypeRegistry *PyTreeTypeRegistry::Singleton(); +ssize_t PyTreeTypeRegistry::Size(const std::optional ®istry_namespace) const { + const scoped_read_lock lock{sm_mutex}; + + if (registry_namespace.has_value()) [[unlikely]] { + if (registry_namespace->empty()) [[likely]] { + return static_cast(m_registrations.size()); + } + ssize_t count = 0; + for (const auto &entry : m_named_registrations) { + if (entry.first.first == *registry_namespace) [[likely]] { + ++count; + } + } + return static_cast(m_registrations.size()) + count; + } + return static_cast(m_registrations.size() + m_named_registrations.size()); +} + template /*static*/ void PyTreeTypeRegistry::RegisterImpl(const py::object &cls, const py::function &flatten_func, diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index 26969491..65d6425a 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -22,6 +22,9 @@ import pytest from helpers import ( + ANDROID, + IOS, + PYBIND11_HAS_SUBINTERPRETER_SUPPORT, PYPY, WASM, Py_DEBUG, @@ -32,8 +35,11 @@ if ( PYPY or WASM + or IOS + or ANDROID or sys.version_info < (3, 14) or not getattr(sys.implementation, 'supports_isolated_interpreters', False) + or not PYBIND11_HAS_SUBINTERPRETER_SUPPORT ): pytest.skip('Test for CPython 3.14+ only', allow_module_level=True) diff --git a/tests/helpers.py b/tests/helpers.py index e81a103f..5bf0d8aa 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -32,25 +32,37 @@ import pytest import optree -import optree._C +from optree._C import ( + PYBIND11_HAS_NATIVE_ENUM, + PYBIND11_HAS_SUBINTERPRETER_SUPPORT, + Py_DEBUG, + Py_GIL_DISABLED, + registry_size, +) from optree.registry import __GLOBAL_NAMESPACE as GLOBAL_NAMESPACE +from optree.registry import _NODETYPE_REGISTRY as NODETYPE_REGISTRY TEST_ROOT = Path(__file__).absolute().parent +INITIAL_REGISTRY_SIZE = registry_size() +assert INITIAL_REGISTRY_SIZE == 8 +assert INITIAL_REGISTRY_SIZE + 2 == len(NODETYPE_REGISTRY) + +_ = PYBIND11_HAS_NATIVE_ENUM +_ = PYBIND11_HAS_SUBINTERPRETER_SUPPORT + if sysconfig.get_config_var('Py_DEBUG') is None: - Py_DEBUG = hasattr(sys, 'gettotalrefcount') + assert Py_DEBUG == hasattr(sys, 'gettotalrefcount') else: - Py_DEBUG = bool(int(sysconfig.get_config_var('Py_DEBUG') or '0')) -assert Py_DEBUG == optree._C.Py_DEBUG + assert Py_DEBUG == bool(int(sysconfig.get_config_var('Py_DEBUG') or '0')) skipif_pydebug = pytest.mark.skipif( Py_DEBUG, reason='Py_DEBUG is enabled which causes too much overhead', ) -Py_GIL_DISABLED = bool(int(sysconfig.get_config_var('Py_GIL_DISABLED') or '0')) -assert Py_GIL_DISABLED == optree._C.Py_GIL_DISABLED +assert Py_GIL_DISABLED == bool(int(sysconfig.get_config_var('Py_GIL_DISABLED') or '0')) skipif_freethreading = pytest.mark.skipif( Py_GIL_DISABLED, reason='Py_GIL_DISABLED is set', diff --git a/tests/test_registry.py b/tests/test_registry.py index c78a5e9a..7f9100c8 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -27,6 +27,7 @@ import optree._C from helpers import ( GLOBAL_NAMESPACE, + NODETYPE_REGISTRY, PYPY, Py_GIL_DISABLED, disable_systrace, @@ -1273,3 +1274,43 @@ def tree_unflatten(cls, metadata, children): assert hasattr(WrappingTestClass, 'tree_unflatten') assert 
callable(WrappingTestClass.tree_flatten) assert callable(WrappingTestClass.tree_unflatten) + + +def test_registry_size(): + initial_size = len(optree.register_pytree_node.get()) + initial_total_size = len(NODETYPE_REGISTRY) + initial_registry_size = len(optree.register_pytree_node.get(namespace='mylist')) + assert optree._C.registry_size() + 2 == initial_total_size + assert optree._C.registry_size(namespace='') + 2 == initial_size + assert optree._C.registry_size(namespace='mylist') + 2 == initial_registry_size + assert optree._C.registry_size(namespace='undefined') + 2 == initial_size + assert len(optree.register_pytree_node.get(namespace='undefined')) == initial_size + + @optree.register_pytree_node_class(namespace='mylist') + class MyList(UserList): + def __tree_flatten__(self): + return self.data, None, None + + @classmethod + def __tree_unflatten__(cls, metadata, children): + return cls(children) + + assert len(optree.register_pytree_node.get()) == initial_size + assert len(NODETYPE_REGISTRY) == initial_total_size + 1 + assert len(optree.register_pytree_node.get(namespace='mylist')) == initial_registry_size + 1 + assert optree._C.registry_size() + 2 == initial_total_size + 1 + assert optree._C.registry_size(namespace='') + 2 == initial_size + assert optree._C.registry_size(namespace='mylist') + 2 == initial_registry_size + 1 + assert optree._C.registry_size(namespace='undefined') + 2 == initial_size + assert len(optree.register_pytree_node.get(namespace='undefined')) == initial_size + + optree.unregister_pytree_node(MyList, namespace='mylist') + + assert len(optree.register_pytree_node.get()) == initial_size + assert len(NODETYPE_REGISTRY) == initial_total_size + assert len(optree.register_pytree_node.get(namespace='mylist')) == initial_registry_size + assert optree._C.registry_size() + 2 == initial_total_size + assert optree._C.registry_size(namespace='') + 2 == initial_size + assert optree._C.registry_size(namespace='mylist') + 2 == initial_registry_size + assert optree._C.registry_size(namespace='undefined') + 2 == initial_size + assert len(optree.register_pytree_node.get(namespace='undefined')) == initial_size diff --git a/tests/test_typing.py b/tests/test_typing.py index 3a4400ad..4e428735 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -26,8 +26,8 @@ import pytest import optree -import optree._C from helpers import ( + PYBIND11_HAS_NATIVE_ENUM, CustomNamedTupleSubclass, CustomTuple, Py_GIL_DISABLED, @@ -72,7 +72,7 @@ class FakeStructSequence(tuple): def test_pytreekind_enum(): - if optree._C.PYBIND11_HAS_NATIVE_ENUM: + if PYBIND11_HAS_NATIVE_ENUM: all_kinds = list(optree.PyTreeKind) assert len(all_kinds) == optree.PyTreeKind.NUM_KINDS assert issubclass(optree.PyTreeKind, enum.IntEnum) From e4f201947f3d8e098d5a51f3eb331ac18778ea5c Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 1 Dec 2025 00:50:51 +0800 Subject: [PATCH 08/59] feat: set `py::multiple_interpreters::per_interpreter_gil()` --- include/optree/registry.h | 14 ++- optree/_C.pyi | 6 +- src/optree.cpp | 89 +++++++++---- src/registry.cpp | 152 +++++++++++++++++------ src/treespec/gc.cpp | 4 +- tests/concurrent/test_subinterpreters.py | 25 ++-- tests/helpers.py | 4 +- tests/test_registry.py | 24 ++-- 8 files changed, 225 insertions(+), 93 deletions(-) diff --git a/include/optree/registry.h b/include/optree/registry.h index 27c31bea..bf1eed70 100644 --- a/include/optree/registry.h +++ b/include/optree/registry.h @@ -67,6 +67,14 @@ constexpr PyTreeKind kDeque = PyTreeKind::Deque; constexpr PyTreeKind 
kStructSequence = PyTreeKind::StructSequence; constexpr PyTreeKind kNumPyTreeKinds = PyTreeKind::NumKinds; +[[nodiscard]] inline ssize_t GetPyInterpreterID() { + PyInterpreterState *interp = PyInterpreterState_Get(); + if (interp == nullptr) [[unlikely]] { + throw std::runtime_error("Failed to get the current Python interpreter state."); + } + return PyInterpreterState_GetID(interp); +} + // Registry of custom node types. class PyTreeTypeRegistry { public: @@ -141,13 +149,17 @@ class PyTreeTypeRegistry { const std::string ®istry_namespace); // Clear the registry on cleanup. - static void Clear(); + static void Clear(const std::optional &interpreter_id = std::nullopt); std::unordered_map m_registrations{}; std::unordered_map, RegistrationPtr> m_named_registrations{}; static inline std::unordered_set sm_builtins_types{}; + static inline std::unordered_map> + sm_interpreter_scoped_registered_types{}; static inline read_write_mutex sm_mutex{}; + static inline ssize_t sm_num_interpreters_alive = 0; + static inline ssize_t sm_num_interpreters_seen = 0; }; } // namespace optree diff --git a/optree/_C.pyi b/optree/_C.pyi index 4c92f799..feae03d8 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -213,6 +213,6 @@ def set_dict_insertion_ordered( /, namespace: str = '', ) -> None: ... -def registry_size( - namespace: str | None = None, -) -> int: ... +def get_num_interpreters_seen() -> int: ... +def get_num_interpreters_alive() -> int: ... +def get_registry_size(namespace: str | None = None) -> int: ... diff --git a/src/optree.cpp b/src/optree.cpp index 49582070..b21f92c6 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -17,11 +17,12 @@ limitations under the License. #include "optree/optree.h" -#include // std::{not_,}equal_to, std::less{,_equal}, std::greater{,_equal} -#include // std::unique_ptr -#include // std::optional, std::nullopt -#include // std::string -#include // std::move +#include // std::{not_,}equal_to, std::less{,_equal}, std::greater{,_equal} +#include // std::unique_ptr +#include // std::optional, std::nullopt +#include // std::string +#include // std::unordered_set +#include // std::move #include #include @@ -141,19 +142,6 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] py::arg("cls"), py::pos_only(), py::arg("namespace") = "") - .def( - "registry_size", - [](const std::optional ®istry_namespace) { - const ssize_t count = - PyTreeTypeRegistry::Singleton()->Size(registry_namespace); - EXPECT_EQ(count, - PyTreeTypeRegistry::Singleton()->Size(registry_namespace) + 1, - "The number of registered types in the two registries should match " - "up to the extra None type in the NoneIsLeaf=false registry."); - return count; - }, - "Get the number of registered types.", - py::arg("namespace") = std::nullopt) .def("is_dict_insertion_ordered", &PyTreeSpec::IsDictInsertionOrdered, "Return whether need to preserve the dict insertion order during flattening.", @@ -165,6 +153,56 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] py::arg("mode"), py::pos_only(), py::arg("namespace") = "") + .def( + "get_num_interpreters_seen", + []() -> size_t { + const scoped_read_lock lock{PyTreeTypeRegistry::sm_mutex}; + return PyTreeTypeRegistry::sm_num_interpreters_seen; + }, + "Get the number of interpreters that have seen the registry.") + .def( + "get_num_interpreters_alive", + []() -> size_t { + const scoped_read_lock lock{PyTreeTypeRegistry::sm_mutex}; + EXPECT_EQ(py::ssize_t_cast( + PyTreeTypeRegistry::sm_interpreter_scoped_registered_types.size()), + 
PyTreeTypeRegistry::sm_num_interpreters_alive, + "The number of alive interpreters should match the size of the " + "interpreter-scoped registered types map."); + return PyTreeTypeRegistry::sm_num_interpreters_alive; + }, + "Get the number of alive interpreters that have seen the registry.") + .def( + "get_alive_interpreter_ids", + []() -> std::unordered_set { + const scoped_read_lock lock{PyTreeTypeRegistry::sm_mutex}; + EXPECT_EQ(py::ssize_t_cast( + PyTreeTypeRegistry::sm_interpreter_scoped_registered_types.size()), + PyTreeTypeRegistry::sm_num_interpreters_alive, + "The number of alive interpreters should match the size of the " + "interpreter-scoped registered types map."); + std::unordered_set ids; + for (const auto &[id, _] : + PyTreeTypeRegistry::sm_interpreter_scoped_registered_types) { + ids.insert(id); + } + return ids; + }, + "Get the IDs of alive interpreters that have seen the registry.") + .def( + "get_registry_size", + [](const std::optional ®istry_namespace) { + const ssize_t count = + PyTreeTypeRegistry::Singleton()->Size(registry_namespace); + EXPECT_EQ( + count, + PyTreeTypeRegistry::Singleton()->Size(registry_namespace) + 1, + "The number of registered types in the two registries should match " + "up to the extra None type in the NoneIsNode registry."); + return count; + }, + "Get the number of registered types.", + py::arg("namespace") = std::nullopt) .def("flatten", &PyTreeSpec::Flatten, "Flattens a pytree.", @@ -533,8 +571,17 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] PyType_Modified(PyTreeSpec_Type); PyType_Modified(PyTreeIter_Type); - py::getattr(py::module_::import("atexit"), - "register")(py::cpp_function(&PyTreeTypeRegistry::Clear)); + const ssize_t interpreter_id = GetPyInterpreterID(); + + { + const scoped_write_lock lock{PyTreeTypeRegistry::sm_mutex}; + ++PyTreeTypeRegistry::sm_num_interpreters_alive; + ++PyTreeTypeRegistry::sm_num_interpreters_seen; + } + (void)PyTreeTypeRegistry::Singleton(); + (void)PyTreeTypeRegistry::Singleton(); + py::getattr(py::module_::import("atexit"), "register")(py::cpp_function( + [interpreter_id]() -> void { PyTreeTypeRegistry::Clear(interpreter_id); })); } } // namespace optree @@ -543,7 +590,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] #if PYBIND11_VERSION_HEX >= 0x020D00F0 // pybind11 2.13.0 # if defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \ NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) -PYBIND11_MODULE(_C, mod, py::mod_gil_not_used(), py::multiple_interpreters::shared_gil()) +PYBIND11_MODULE(_C, mod, py::mod_gil_not_used(), py::multiple_interpreters::per_interpreter_gil()) # else PYBIND11_MODULE(_C, mod, py::mod_gil_not_used()) # endif diff --git a/src/registry.cpp b/src/registry.cpp index 03fa8ba3..6f67eeaa 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -15,12 +15,14 @@ limitations under the License. 
================================================================================ */ -#include // std::make_shared -#include // std::optional -#include // std::ostringstream -#include // std::string -#include // std::remove_const_t -#include // std::move, std::make_pair +#include // std::make_shared +#include // std::optional +#include // std::ostringstream +#include // std::string +#include // std::remove_const_t +#include // std::unordered_set +#include // std::move, std::make_pair +#include // std::vector #include @@ -35,8 +37,15 @@ template .call_once_and_store_result([]() -> PyTreeTypeRegistry { PyTreeTypeRegistry registry{}; - const auto add_builtin_type = [®istry](const py::object &cls, - const PyTreeKind &kind) -> void { + const ssize_t interpreter_id = GetPyInterpreterID(); + auto &interpreter_scoped_registered_types = + sm_interpreter_scoped_registered_types.try_emplace(interpreter_id) + .first->second; + + const auto add_builtin_type = [®istry, + &interpreter_scoped_registered_types]( + const py::object &cls, + const PyTreeKind &kind) -> void { auto registration = std::make_shared>(); registration->kind = kind; @@ -47,6 +56,7 @@ template " is already registered in the global namespace."); if (sm_builtins_types.emplace(cls).second) [[likely]] { cls.inc_ref(); + interpreter_scoped_registered_types.emplace(cls); } }; if constexpr (!NoneIsLeaf) { @@ -69,19 +79,27 @@ template PyTreeTypeRegistry *PyTreeTypeRegistry::Singleton(); ssize_t PyTreeTypeRegistry::Size(const std::optional ®istry_namespace) const { const scoped_read_lock lock{sm_mutex}; - if (registry_namespace.has_value()) [[unlikely]] { - if (registry_namespace->empty()) [[likely]] { - return static_cast(m_registrations.size()); + const ssize_t interpreter_id = GetPyInterpreterID(); + const auto &it = sm_interpreter_scoped_registered_types.find(interpreter_id); + const auto &interpreter_scoped_registered_types = + (it != sm_interpreter_scoped_registered_types.end() ? 
it->second + : std::unordered_set{}); + + ssize_t count = 0; + for (const auto &[type, _] : m_registrations) { + if (interpreter_scoped_registered_types.find(type) != + interpreter_scoped_registered_types.end()) [[likely]] { + ++count; } - ssize_t count = 0; - for (const auto &entry : m_named_registrations) { - if (entry.first.first == *registry_namespace) [[likely]] { - ++count; - } + } + for (const auto &[named_type, _] : m_named_registrations) { + if (interpreter_scoped_registered_types.find(named_type.second) != + interpreter_scoped_registered_types.end() && + (!registry_namespace || named_type.first == *registry_namespace)) [[likely]] { + ++count; } - return static_cast(m_registrations.size()) + count; } - return static_cast(m_registrations.size() + m_named_registrations.size()); + return count; } template @@ -178,6 +196,10 @@ template flatten_func.inc_ref(); unflatten_func.inc_ref(); path_entry_type.inc_ref(); + + const ssize_t interpreter_id = GetPyInterpreterID(); + const auto &it = sm_interpreter_scoped_registered_types.try_emplace(interpreter_id).first; + it->second.emplace(cls); } template @@ -309,7 +331,7 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( const std::string &); // NOLINTNEXTLINE[readability-function-cognitive-complexity] -/*static*/ void PyTreeTypeRegistry::Clear() { +/*static*/ void PyTreeTypeRegistry::Clear(const std::optional &interpreter_id) { const scoped_write_lock lock{sm_mutex}; PyTreeTypeRegistry * const registry1 = PyTreeTypeRegistry::Singleton(); @@ -323,23 +345,21 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( for (const auto &cls : sm_builtins_types) { EXPECT_NE(registry1->m_registrations.find(cls), registry1->m_registrations.end()); } - for (const auto &entry : registry2->m_registrations) { - const auto it = registry1->m_registrations.find(entry.first); - EXPECT_NE(it, registry1->m_registrations.end()); + for (const auto &[cls2, registration2] : registry2->m_registrations) { + const auto it1 = registry1->m_registrations.find(cls2); + EXPECT_NE(it1, registry1->m_registrations.end()); - const auto ®istration1 = it->second; - const auto ®istration2 = entry.second; + const auto ®istration1 = it1->second; EXPECT_TRUE(registration1->type.is(registration2->type)); EXPECT_TRUE(registration1->flatten_func.is(registration2->flatten_func)); EXPECT_TRUE(registration1->unflatten_func.is(registration2->unflatten_func)); EXPECT_TRUE(registration1->path_entry_type.is(registration2->path_entry_type)); } - for (const auto &entry : registry2->m_named_registrations) { - const auto it = registry1->m_named_registrations.find(entry.first); - EXPECT_NE(it, registry1->m_named_registrations.end()); + for (const auto &[cls2, registration2] : registry2->m_named_registrations) { + const auto it1 = registry1->m_named_registrations.find(cls2); + EXPECT_NE(it1, registry1->m_named_registrations.end()); - const auto ®istration1 = it->second; - const auto ®istration2 = entry.second; + const auto ®istration1 = it1->second; EXPECT_TRUE(registration1->type.is(registration2->type)); EXPECT_TRUE(registration1->flatten_func.is(registration2->flatten_func)); EXPECT_TRUE(registration1->unflatten_func.is(registration2->unflatten_func)); @@ -347,17 +367,71 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( } #endif - for (const auto &entry : registry1->m_registrations) { - entry.second->type.dec_ref(); - entry.second->flatten_func.dec_ref(); - entry.second->unflatten_func.dec_ref(); - entry.second->path_entry_type.dec_ref(); + 
EXPECT_EQ(py::ssize_t_cast(sm_interpreter_scoped_registered_types.size()), + sm_num_interpreters_alive); + + if (!interpreter_id) [[unlikely]] { + EXPECT_EQ(sm_num_interpreters_alive, + 1, + "interpreter_id must be provided when multiple interpreters are alive."); + } + + if (sm_num_interpreters_alive > 1) [[unlikely]] { + EXPECT_TRUE(interpreter_id, + "interpreter_id must be provided when multiple interpreters are alive."); + const auto it = sm_interpreter_scoped_registered_types.find(*interpreter_id); + if (it != sm_interpreter_scoped_registered_types.end()) [[likely]] { + for (const auto &cls : it->second) { + const auto reg_it1 = registry1->m_registrations.find(cls); + if (reg_it1 != registry1->m_registrations.end()) [[likely]] { + const auto registration1 = reg_it1->second; + registration1->type.dec_ref(); + registration1->flatten_func.dec_ref(); + registration1->unflatten_func.dec_ref(); + registration1->path_entry_type.dec_ref(); + registry1->m_registrations.erase(reg_it1); + } + std::vector> named_types_to_erase{}; + for (const auto &[named_type, _] : registry1->m_named_registrations) { + if (named_type.second.is(cls)) [[likely]] { + named_types_to_erase.emplace_back(named_type); + } + } + for (const auto &named_type : named_types_to_erase) { + const auto named_it1 = registry1->m_named_registrations.find(named_type); + EXPECT_NE(named_it1, registry1->m_named_registrations.end()); + const auto registration1 = named_it1->second; + registration1->type.dec_ref(); + registration1->flatten_func.dec_ref(); + registration1->unflatten_func.dec_ref(); + registration1->path_entry_type.dec_ref(); + registry1->m_named_registrations.erase(named_it1); + } + registry2->m_registrations.erase(cls); + for (const auto &named_type : named_types_to_erase) { + registry2->m_named_registrations.erase(named_type); + } + } + sm_interpreter_scoped_registered_types.erase(it); + } else [[unlikely]] { + INTERNAL_ERROR("Interpreter ID " + std::to_string(*interpreter_id) + + " not found in `sm_interpreter_scoped_registered_types`."); + } + --sm_num_interpreters_alive; + return; + } + + for (const auto &[_, registration] : registry1->m_registrations) { + registration->type.dec_ref(); + registration->flatten_func.dec_ref(); + registration->unflatten_func.dec_ref(); + registration->path_entry_type.dec_ref(); } - for (const auto &entry : registry1->m_named_registrations) { - entry.second->type.dec_ref(); - entry.second->flatten_func.dec_ref(); - entry.second->unflatten_func.dec_ref(); - entry.second->path_entry_type.dec_ref(); + for (const auto &[_, registration] : registry1->m_named_registrations) { + registration->type.dec_ref(); + registration->flatten_func.dec_ref(); + registration->unflatten_func.dec_ref(); + registration->path_entry_type.dec_ref(); } sm_builtins_types.clear(); diff --git a/src/treespec/gc.cpp b/src/treespec/gc.cpp index 025e6178..7cdb52c2 100644 --- a/src/treespec/gc.cpp +++ b/src/treespec/gc.cpp @@ -62,8 +62,8 @@ namespace optree { return 0; } auto &self = thread_safe_cast(py::handle{self_base}); - for (const auto &pair : self.m_agenda) { - Py_VISIT(pair.first.ptr()); + for (const auto &[obj, _] : self.m_agenda) { + Py_VISIT(obj.ptr()); } Py_VISIT(self.m_root.ptr()); return 0; diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index 65d6425a..0133f14c 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -15,8 +15,7 @@ import atexit import contextlib -import importlib -import re +import random import 
sys import pytest @@ -81,16 +80,16 @@ def concurrent_run(func, /, *args, **kwargs): run(object) # warm-up -def test_import_failure(): - pattern = re.escape('module optree._C does not support loading in subinterpreters') - with pytest.raises(ImportError, match=pattern) as excinfo: - run(importlib.import_module, 'optree') - - with contextlib.closing(interpreters.create()) as subinterpreter: - with pytest.raises(interpreters.ExecutionFailed, match=pattern) as excinfo: - subinterpreter.call(importlib.import_module, 'optree') - assert excinfo.value.excinfo.type.__name__ == 'ImportError' +def test_import(): + for _ in range(random.randint(5, 10)): + with contextlib.closing(interpreters.create()) as subinterpreter: + subinterpreter.exec('import optree') - with pytest.raises(interpreters.ExecutionFailed, match=pattern) as excinfo: + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range(random.randint(5, 10)) + ] + random.shuffle(subinterpreters) + for subinterpreter in subinterpreters: subinterpreter.exec('import optree') - assert excinfo.value.excinfo.type.__name__ == 'ImportError' diff --git a/tests/helpers.py b/tests/helpers.py index 5bf0d8aa..ebe500bb 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -37,7 +37,7 @@ PYBIND11_HAS_SUBINTERPRETER_SUPPORT, Py_DEBUG, Py_GIL_DISABLED, - registry_size, + get_registry_size, ) from optree.registry import __GLOBAL_NAMESPACE as GLOBAL_NAMESPACE from optree.registry import _NODETYPE_REGISTRY as NODETYPE_REGISTRY @@ -46,7 +46,7 @@ TEST_ROOT = Path(__file__).absolute().parent -INITIAL_REGISTRY_SIZE = registry_size() +INITIAL_REGISTRY_SIZE = get_registry_size() assert INITIAL_REGISTRY_SIZE == 8 assert INITIAL_REGISTRY_SIZE + 2 == len(NODETYPE_REGISTRY) diff --git a/tests/test_registry.py b/tests/test_registry.py index 7f9100c8..256edb1f 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -1280,10 +1280,10 @@ def test_registry_size(): initial_size = len(optree.register_pytree_node.get()) initial_total_size = len(NODETYPE_REGISTRY) initial_registry_size = len(optree.register_pytree_node.get(namespace='mylist')) - assert optree._C.registry_size() + 2 == initial_total_size - assert optree._C.registry_size(namespace='') + 2 == initial_size - assert optree._C.registry_size(namespace='mylist') + 2 == initial_registry_size - assert optree._C.registry_size(namespace='undefined') + 2 == initial_size + assert optree._C.get_registry_size() + 2 == initial_total_size + assert optree._C.get_registry_size(namespace='') + 2 == initial_size + assert optree._C.get_registry_size(namespace='mylist') + 2 == initial_registry_size + assert optree._C.get_registry_size(namespace='undefined') + 2 == initial_size assert len(optree.register_pytree_node.get(namespace='undefined')) == initial_size @optree.register_pytree_node_class(namespace='mylist') @@ -1298,10 +1298,10 @@ def __tree_unflatten__(cls, metadata, children): assert len(optree.register_pytree_node.get()) == initial_size assert len(NODETYPE_REGISTRY) == initial_total_size + 1 assert len(optree.register_pytree_node.get(namespace='mylist')) == initial_registry_size + 1 - assert optree._C.registry_size() + 2 == initial_total_size + 1 - assert optree._C.registry_size(namespace='') + 2 == initial_size - assert optree._C.registry_size(namespace='mylist') + 2 == initial_registry_size + 1 - assert optree._C.registry_size(namespace='undefined') + 2 == initial_size + assert optree._C.get_registry_size() + 2 == initial_total_size 
+ 1 + assert optree._C.get_registry_size(namespace='') + 2 == initial_size + assert optree._C.get_registry_size(namespace='mylist') + 2 == initial_registry_size + 1 + assert optree._C.get_registry_size(namespace='undefined') + 2 == initial_size assert len(optree.register_pytree_node.get(namespace='undefined')) == initial_size optree.unregister_pytree_node(MyList, namespace='mylist') @@ -1309,8 +1309,8 @@ def __tree_unflatten__(cls, metadata, children): assert len(optree.register_pytree_node.get()) == initial_size assert len(NODETYPE_REGISTRY) == initial_total_size assert len(optree.register_pytree_node.get(namespace='mylist')) == initial_registry_size - assert optree._C.registry_size() + 2 == initial_total_size - assert optree._C.registry_size(namespace='') + 2 == initial_size - assert optree._C.registry_size(namespace='mylist') + 2 == initial_registry_size - assert optree._C.registry_size(namespace='undefined') + 2 == initial_size + assert optree._C.get_registry_size() + 2 == initial_total_size + assert optree._C.get_registry_size(namespace='') + 2 == initial_size + assert optree._C.get_registry_size(namespace='mylist') + 2 == initial_registry_size + assert optree._C.get_registry_size(namespace='undefined') + 2 == initial_size assert len(optree.register_pytree_node.get(namespace='undefined')) == initial_size From eec46026652741ce1378eef5c712f6321fca20a7 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 1 Dec 2025 11:04:07 +0800 Subject: [PATCH 09/59] chore: disable subinterpreters for pre-3.14 --- include/optree/pymacros.h | 27 +++++++++++++++++++++++++++ include/optree/registry.h | 8 -------- optree/_C.pyi | 1 + src/optree.cpp | 5 +---- 4 files changed, 29 insertions(+), 12 deletions(-) diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h index 7f2ac9ab..13841bf2 100644 --- a/include/optree/pymacros.h +++ b/include/optree/pymacros.h @@ -17,6 +17,8 @@ limitations under the License. #pragma once +#include // std::runtime_error + #include #include @@ -91,3 +93,28 @@ Py_Declare_ID(_asdict); // namedtuple._asdict Py_Declare_ID(n_fields); // structseq.n_fields Py_Declare_ID(n_sequence_fields); // structseq.n_sequence_fields Py_Declare_ID(n_unnamed_fields); // structseq.n_unnamed_fields + +// NOLINTNEXTLINE[bugprone-macro-parentheses] +#define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0)) + +#if !defined(PYPY_VERSION) && \ + (defined(PY_VERSION_HEX) && PY_VERSION_HEX >= 0x030E0000 /* Python 3.14 */) && \ + (defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \ + NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT)) + +inline py::ssize_t GetPyInterpreterID() { + PyInterpreterState *interp = PyInterpreterState_Get(); + if (interp == nullptr) [[unlikely]] { + throw std::runtime_error("Failed to get the current Python interpreter state."); + } + return PyInterpreterState_GetID(interp); +} + +#else + +inline constexpr py::ssize_t GetPyInterpreterID() noexcept { + // Fallback for Python versions < 3.14 or when subinterpreter support is not available. 
+ return 0; +} + +#endif diff --git a/include/optree/registry.h b/include/optree/registry.h index bf1eed70..0069fa5f 100644 --- a/include/optree/registry.h +++ b/include/optree/registry.h @@ -67,14 +67,6 @@ constexpr PyTreeKind kDeque = PyTreeKind::Deque; constexpr PyTreeKind kStructSequence = PyTreeKind::StructSequence; constexpr PyTreeKind kNumPyTreeKinds = PyTreeKind::NumKinds; -[[nodiscard]] inline ssize_t GetPyInterpreterID() { - PyInterpreterState *interp = PyInterpreterState_Get(); - if (interp == nullptr) [[unlikely]] { - throw std::runtime_error("Failed to get the current Python interpreter state."); - } - return PyInterpreterState_GetID(interp); -} - // Registry of custom node types. class PyTreeTypeRegistry { public: diff --git a/optree/_C.pyi b/optree/_C.pyi index feae03d8..6e4209bf 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -215,4 +215,5 @@ def set_dict_insertion_ordered( ) -> None: ... def get_num_interpreters_seen() -> int: ... def get_num_interpreters_alive() -> int: ... +def get_alive_interpreter_ids() -> set[int]: ... def get_registry_size(namespace: str | None = None) -> int: ... diff --git a/src/optree.cpp b/src/optree.cpp index b21f92c6..aaae3355 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -32,9 +32,6 @@ limitations under the License. # include #endif -// NOLINTNEXTLINE[bugprone-macro-parentheses] -#define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0)) - namespace optree { py::module_ GetCxxModule(const std::optional &module) { @@ -574,7 +571,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] const ssize_t interpreter_id = GetPyInterpreterID(); { - const scoped_write_lock lock{PyTreeTypeRegistry::sm_mutex}; + const scoped_write_lock interp_lock{PyTreeTypeRegistry::sm_mutex}; ++PyTreeTypeRegistry::sm_num_interpreters_alive; ++PyTreeTypeRegistry::sm_num_interpreters_seen; } From fade81a6e906db936cba7a25735083e0b30a596e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 18:48:23 +0000 Subject: [PATCH 10/59] chore(pre-commit): [pre-commit.ci] autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.14.6 → v0.14.7](https://github.com/astral-sh/ruff-pre-commit/compare/v0.14.6...v0.14.7) - [github.com/pre-commit/mirrors-mypy: v1.18.2 → v1.19.0](https://github.com/pre-commit/mirrors-mypy/compare/v1.18.2...v1.19.0) --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cedc15ae..fd33e89f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,7 +38,7 @@ repos: hooks: - id: cpplint - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.6 + rev: v0.14.7 hooks: - id: ruff-check args: [--fix, --exit-non-zero-on-fix] @@ -50,7 +50,7 @@ repos: - id: codespell additional_dependencies: [".[toml]"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.18.2 + rev: v1.19.0 hooks: - id: mypy exclude: | From e7a2a89e0755a91de08611a84760ba536e8184bf Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 1 Dec 2025 22:52:45 +0800 Subject: [PATCH 11/59] refactor: simplify per-interpreter registry --- include/optree/registry.h | 39 +++-- src/optree.cpp | 19 +-- src/registry.cpp | 302 ++++++++++++++++++++------------------ 3 files changed, 190 insertions(+), 170 deletions(-) diff --git a/include/optree/registry.h 
b/include/optree/registry.h index 0069fa5f..291bd03d 100644 --- a/include/optree/registry.h +++ b/include/optree/registry.h @@ -21,6 +21,7 @@ limitations under the License. #include // std::shared_ptr #include // std::optional, std::nullopt #include // std::string +#include // std::tuple #include // std::unordered_map #include // std::unordered_set #include // std::pair @@ -115,19 +116,20 @@ class PyTreeTypeRegistry { // Finds the custom type registration for `type`. Returns nullptr if none exists. template - static RegistrationPtr Lookup(const py::object &cls, const std::string ®istry_namespace); + [[nodiscard]] static RegistrationPtr Lookup(const py::object &cls, + const std::string ®istry_namespace); // Compute the node kind of a given Python object. template - static PyTreeKind GetKind(const py::handle &handle, - RegistrationPtr &custom, // NOLINT[runtime/references] - const std::string ®istry_namespace); + [[nodiscard]] static PyTreeKind GetKind(const py::handle &handle, + RegistrationPtr &custom, // NOLINT[runtime/references] + const std::string ®istry_namespace); friend void BuildModule(py::module_ &mod); // NOLINT[runtime/references] private: template - static PyTreeTypeRegistry *Singleton(); + [[nodiscard]] static PyTreeTypeRegistry *Singleton(); template static void RegisterImpl(const py::object &cls, @@ -137,18 +139,29 @@ class PyTreeTypeRegistry { const std::string ®istry_namespace); template - static RegistrationPtr UnregisterImpl(const py::object &cls, - const std::string ®istry_namespace); + [[nodiscard]] static RegistrationPtr UnregisterImpl(const py::object &cls, + const std::string ®istry_namespace); + + // Initialize the registry for a given interpreter. + void Init(); // Clear the registry on cleanup. - static void Clear(const std::optional &interpreter_id = std::nullopt); + static void Clear(); + + using RegistrationsMap = std::unordered_map; + using NamedRegistrationsMap = + std::unordered_map, RegistrationPtr>; + using BuiltinsTypesSet = std::unordered_set; + + template + [[nodiscard]] inline std::tuple + GetRegistrationsForInterpreterLocked() const; - std::unordered_map m_registrations{}; - std::unordered_map, RegistrationPtr> m_named_registrations{}; + bool m_none_is_leaf = false; + std::unordered_map m_registrations{}; + std::unordered_map m_named_registrations{}; - static inline std::unordered_set sm_builtins_types{}; - static inline std::unordered_map> - sm_interpreter_scoped_registered_types{}; + static inline std::unordered_map sm_builtins_types{}; static inline read_write_mutex sm_mutex{}; static inline ssize_t sm_num_interpreters_alive = 0; static inline ssize_t sm_num_interpreters_seen = 0; diff --git a/src/optree.cpp b/src/optree.cpp index aaae3355..f7e28027 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -161,8 +161,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] "get_num_interpreters_alive", []() -> size_t { const scoped_read_lock lock{PyTreeTypeRegistry::sm_mutex}; - EXPECT_EQ(py::ssize_t_cast( - PyTreeTypeRegistry::sm_interpreter_scoped_registered_types.size()), + EXPECT_EQ(py::ssize_t_cast(PyTreeTypeRegistry::sm_builtins_types.size()), PyTreeTypeRegistry::sm_num_interpreters_alive, "The number of alive interpreters should match the size of the " "interpreter-scoped registered types map."); @@ -173,14 +172,12 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] "get_alive_interpreter_ids", []() -> std::unordered_set { const scoped_read_lock lock{PyTreeTypeRegistry::sm_mutex}; - EXPECT_EQ(py::ssize_t_cast( - 
PyTreeTypeRegistry::sm_interpreter_scoped_registered_types.size()), + EXPECT_EQ(py::ssize_t_cast(PyTreeTypeRegistry::sm_builtins_types.size()), PyTreeTypeRegistry::sm_num_interpreters_alive, "The number of alive interpreters should match the size of the " "interpreter-scoped registered types map."); std::unordered_set ids; - for (const auto &[id, _] : - PyTreeTypeRegistry::sm_interpreter_scoped_registered_types) { + for (const auto &[id, _] : PyTreeTypeRegistry::sm_builtins_types) { ids.insert(id); } return ids; @@ -568,17 +565,15 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] PyType_Modified(PyTreeSpec_Type); PyType_Modified(PyTreeIter_Type); - const ssize_t interpreter_id = GetPyInterpreterID(); - { const scoped_write_lock interp_lock{PyTreeTypeRegistry::sm_mutex}; ++PyTreeTypeRegistry::sm_num_interpreters_alive; ++PyTreeTypeRegistry::sm_num_interpreters_seen; } - (void)PyTreeTypeRegistry::Singleton(); - (void)PyTreeTypeRegistry::Singleton(); - py::getattr(py::module_::import("atexit"), "register")(py::cpp_function( - [interpreter_id]() -> void { PyTreeTypeRegistry::Clear(interpreter_id); })); + PyTreeTypeRegistry::Singleton()->Init(); + PyTreeTypeRegistry::Singleton()->Init(); + py::getattr(py::module_::import("atexit"), + "register")(py::cpp_function(&PyTreeTypeRegistry::Clear)); } } // namespace optree diff --git a/src/registry.cpp b/src/registry.cpp index 6f67eeaa..6c73fdc6 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -19,10 +19,10 @@ limitations under the License. #include // std::optional #include // std::ostringstream #include // std::string +#include // std::tuple, std::make_tuple #include // std::remove_const_t #include // std::unordered_set #include // std::move, std::make_pair -#include // std::vector #include @@ -36,38 +36,7 @@ template return &(storage .call_once_and_store_result([]() -> PyTreeTypeRegistry { PyTreeTypeRegistry registry{}; - - const ssize_t interpreter_id = GetPyInterpreterID(); - auto &interpreter_scoped_registered_types = - sm_interpreter_scoped_registered_types.try_emplace(interpreter_id) - .first->second; - - const auto add_builtin_type = [®istry, - &interpreter_scoped_registered_types]( - const py::object &cls, - const PyTreeKind &kind) -> void { - auto registration = - std::make_shared>(); - registration->kind = kind; - registration->type = py::reinterpret_borrow(cls); - EXPECT_TRUE( - registry.m_registrations.emplace(cls, std::move(registration)).second, - "PyTree type " + PyRepr(cls) + - " is already registered in the global namespace."); - if (sm_builtins_types.emplace(cls).second) [[likely]] { - cls.inc_ref(); - interpreter_scoped_registered_types.emplace(cls); - } - }; - if constexpr (!NoneIsLeaf) { - add_builtin_type(PyNoneTypeObject, PyTreeKind::None); - } - add_builtin_type(PyTupleTypeObject, PyTreeKind::Tuple); - add_builtin_type(PyListTypeObject, PyTreeKind::List); - add_builtin_type(PyDictTypeObject, PyTreeKind::Dict); - add_builtin_type(PyOrderedDictTypeObject, PyTreeKind::OrderedDict); - add_builtin_type(PyDefaultDictTypeObject, PyTreeKind::DefaultDict); - add_builtin_type(PyDequeTypeObject, PyTreeKind::Deque); + registry.m_none_is_leaf = NoneIsLeaf; return registry; }) .get_stored()); @@ -76,26 +45,100 @@ template template PyTreeTypeRegistry *PyTreeTypeRegistry::Singleton(); template PyTreeTypeRegistry *PyTreeTypeRegistry::Singleton(); -ssize_t PyTreeTypeRegistry::Size(const std::optional ®istry_namespace) const { - const scoped_read_lock lock{sm_mutex}; +template +std::tuple 
+PyTreeTypeRegistry::GetRegistrationsForInterpreterLocked() const { + const ssize_t interpreter_id = GetPyInterpreterID(); + + EXPECT_NE( + m_registrations.find(interpreter_id), + m_registrations.end(), + "Interpreter ID " + std::to_string(interpreter_id) + " not found in `m_registrations`."); + EXPECT_NE(m_named_registrations.find(interpreter_id), + m_named_registrations.end(), + "Interpreter ID " + std::to_string(interpreter_id) + + " not found in `m_named_registrations`."); + EXPECT_NE( + sm_builtins_types.find(interpreter_id), + sm_builtins_types.end(), + "Interpreter ID " + std::to_string(interpreter_id) + " not found in `sm_builtins_types`."); + + const auto ®istrations = m_registrations.at(interpreter_id); + const auto &named_registrations = m_named_registrations.at(interpreter_id); + const auto &builtins_types = sm_builtins_types.at(interpreter_id); + + // NOLINTBEGIN[cppcoreguidelines-pro-type-const-cast] + return std::make_tuple(const_cast(®istrations), + const_cast(&named_registrations), + const_cast(&builtins_types)); + // NOLINTEND[cppcoreguidelines-pro-type-const-cast] +} + +void PyTreeTypeRegistry::Init() { + const scoped_write_lock lock{sm_mutex}; const ssize_t interpreter_id = GetPyInterpreterID(); - const auto &it = sm_interpreter_scoped_registered_types.find(interpreter_id); - const auto &interpreter_scoped_registered_types = - (it != sm_interpreter_scoped_registered_types.end() ? it->second - : std::unordered_set{}); - - ssize_t count = 0; - for (const auto &[type, _] : m_registrations) { - if (interpreter_scoped_registered_types.find(type) != - interpreter_scoped_registered_types.end()) [[likely]] { - ++count; + + EXPECT_EQ(m_registrations.find(interpreter_id), + m_registrations.end(), + "Interpreter ID " + std::to_string(interpreter_id) + + " is already initialized in `m_registrations`."); + EXPECT_EQ(m_named_registrations.find(interpreter_id), + m_named_registrations.end(), + "Interpreter ID " + std::to_string(interpreter_id) + + " is already initialized in `m_named_registrations`."); + if (!m_none_is_leaf) [[likely]] { + EXPECT_EQ(sm_builtins_types.find(interpreter_id), + sm_builtins_types.end(), + "Interpreter ID " + std::to_string(interpreter_id) + + " is already initialized in `sm_builtins_types`."); + } else { + EXPECT_NE(sm_builtins_types.find(interpreter_id), + sm_builtins_types.end(), + "Interpreter ID " + std::to_string(interpreter_id) + + " is not initialized in `sm_builtins_types`."); + } + + auto ®istrations = m_registrations.try_emplace(interpreter_id).first->second; + auto &named_registrations = m_named_registrations.try_emplace(interpreter_id).first->second; + auto &builtins_types = sm_builtins_types.try_emplace(interpreter_id).first->second; + + (void)named_registrations; // silence unused variable warning + + const auto add_builtin_type = + [®istrations, &builtins_types](const py::object &cls, const PyTreeKind &kind) -> void { + auto registration = std::make_shared>(); + registration->kind = kind; + registration->type = py::reinterpret_borrow(cls); + EXPECT_TRUE( + registrations.emplace(cls, std::move(registration)).second, + "PyTree type " + PyRepr(cls) + " is already registered in the global namespace."); + if (builtins_types.emplace(cls).second) [[likely]] { + cls.inc_ref(); } + }; + if (!m_none_is_leaf) [[likely]] { + add_builtin_type(PyNoneTypeObject, PyTreeKind::None); } - for (const auto &[named_type, _] : m_named_registrations) { - if (interpreter_scoped_registered_types.find(named_type.second) != - interpreter_scoped_registered_types.end() 
&& - (!registry_namespace || named_type.first == *registry_namespace)) [[likely]] { + add_builtin_type(PyTupleTypeObject, PyTreeKind::Tuple); + add_builtin_type(PyListTypeObject, PyTreeKind::List); + add_builtin_type(PyDictTypeObject, PyTreeKind::Dict); + add_builtin_type(PyOrderedDictTypeObject, PyTreeKind::OrderedDict); + add_builtin_type(PyDefaultDictTypeObject, PyTreeKind::DefaultDict); + add_builtin_type(PyDequeTypeObject, PyTreeKind::Deque); +} + +ssize_t PyTreeTypeRegistry::Size(const std::optional ®istry_namespace) const { + const scoped_read_lock lock{sm_mutex}; + + const auto [registrations, named_registrations, _] = + GetRegistrationsForInterpreterLocked(); + + ssize_t count = py::ssize_t_cast(registrations->size()); + for (const auto &[named_type, _] : *named_registrations) { + if (!registry_namespace || named_type.first == *registry_namespace) [[likely]] { ++count; } } @@ -108,12 +151,16 @@ template const py::function &unflatten_func, const py::object &path_entry_type, const std::string ®istry_namespace) { - if (sm_builtins_types.find(cls) != sm_builtins_types.end()) [[unlikely]] { + PyTreeTypeRegistry * const registry = Singleton(); + + const auto [registrations, named_registrations, builtins_types] = + registry->GetRegistrationsForInterpreterLocked(); + + if (builtins_types->find(cls) != builtins_types->end()) [[unlikely]] { throw py::value_error("PyTree type " + PyRepr(cls) + " is a built-in type and cannot be re-registered."); } - PyTreeTypeRegistry * const registry = Singleton(); auto registration = std::make_shared>(); registration->kind = PyTreeKind::Custom; registration->type = py::reinterpret_borrow(cls); @@ -121,7 +168,7 @@ template registration->unflatten_func = py::reinterpret_borrow(unflatten_func); registration->path_entry_type = py::reinterpret_borrow(path_entry_type); if (registry_namespace.empty()) [[unlikely]] { - if (!registry->m_registrations.emplace(cls, std::move(registration)).second) [[unlikely]] { + if (!registrations->emplace(cls, std::move(registration)).second) [[unlikely]] { throw py::value_error("PyTree type " + PyRepr(cls) + " is already registered in the global namespace."); } @@ -143,8 +190,8 @@ template /*stack_level=*/2); } } else [[likely]] { - if (!registry->m_named_registrations - .emplace(std::make_pair(registry_namespace, cls), std::move(registration)) + if (!named_registrations + ->emplace(std::make_pair(registry_namespace, cls), std::move(registration)) .second) [[unlikely]] { std::ostringstream oss{}; oss << "PyTree type " << PyRepr(cls) << " is already registered in namespace " @@ -196,25 +243,25 @@ template flatten_func.inc_ref(); unflatten_func.inc_ref(); path_entry_type.inc_ref(); - - const ssize_t interpreter_id = GetPyInterpreterID(); - const auto &it = sm_interpreter_scoped_registered_types.try_emplace(interpreter_id).first; - it->second.emplace(cls); } template /*static*/ PyTreeTypeRegistry::RegistrationPtr PyTreeTypeRegistry::UnregisterImpl( const py::object &cls, const std::string ®istry_namespace) { - if (sm_builtins_types.find(cls) != sm_builtins_types.end()) [[unlikely]] { + PyTreeTypeRegistry * const registry = Singleton(); + + const auto [registrations, named_registrations, builtins_types] = + registry->GetRegistrationsForInterpreterLocked(); + + if (builtins_types->find(cls) != builtins_types->end()) [[unlikely]] { throw py::value_error("PyTree type " + PyRepr(cls) + " is a built-in type and cannot be unregistered."); } - PyTreeTypeRegistry * const registry = Singleton(); if (registry_namespace.empty()) [[unlikely]] { 
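
For illustration of the two paths handled here, the global namespace versus a named one, a minimal sketch with the public API, separate from the diff; `MyList` is just a stand-in container class for the example:

    import optree

    class MyList(list):
        pass

    # Registered only under the 'mylist' namespace: callers that do not pass
    # namespace='mylist' are unaffected by this registration.
    optree.register_pytree_node(
        MyList,
        lambda lst: (list(lst), None),         # flatten: (children, metadata)
        lambda _, children: MyList(children),  # unflatten: (metadata, children)
        namespace='mylist',
    )
    assert optree.tree_flatten(MyList([1, 2, 3]), namespace='mylist')[0] == [1, 2, 3]
    optree.unregister_pytree_node(MyList, namespace='mylist')
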
- const auto it = registry->m_registrations.find(cls); - if (it == registry->m_registrations.end()) [[unlikely]] { + const auto it = registrations->find(cls); + if (it == registrations->end()) [[unlikely]] { std::ostringstream oss{}; oss << "PyTree type " << PyRepr(cls) << " "; if (IsStructSequenceClass(cls)) [[unlikely]] { @@ -229,12 +276,11 @@ template throw py::value_error(oss.str()); } RegistrationPtr registration = it->second; - registry->m_registrations.erase(it); + registrations->erase(it); return registration; } else [[likely]] { - const auto named_it = - registry->m_named_registrations.find(std::make_pair(registry_namespace, cls)); - if (named_it == registry->m_named_registrations.end()) [[unlikely]] { + const auto named_it = named_registrations->find(std::make_pair(registry_namespace, cls)); + if (named_it == named_registrations->end()) [[unlikely]] { std::ostringstream oss{}; oss << "PyTree type " << PyRepr(cls) << " "; if (IsStructSequenceClass(cls)) [[unlikely]] { @@ -250,7 +296,7 @@ template throw py::value_error(oss.str()); } RegistrationPtr registration = named_it->second; - registry->m_named_registrations.erase(named_it); + named_registrations->erase(named_it); return registration; } } @@ -277,16 +323,19 @@ template const std::string ®istry_namespace) { const scoped_read_lock lock{sm_mutex}; - PyTreeTypeRegistry * const registry = Singleton(); + PyTreeTypeRegistry * const registry = PyTreeTypeRegistry::Singleton(); + + const auto [registrations, named_registrations, _] = + registry->GetRegistrationsForInterpreterLocked(); + if (!registry_namespace.empty()) [[unlikely]] { - const auto named_it = - registry->m_named_registrations.find(std::make_pair(registry_namespace, cls)); - if (named_it != registry->m_named_registrations.end()) [[likely]] { + const auto named_it = named_registrations->find(std::make_pair(registry_namespace, cls)); + if (named_it != named_registrations->end()) [[likely]] { return named_it->second; } } - const auto it = registry->m_registrations.find(cls); - return it != registry->m_registrations.end() ? it->second : nullptr; + const auto it = registrations->find(cls); + return it != registrations->end() ? 
it->second : nullptr; } template PyTreeTypeRegistry::RegistrationPtr PyTreeTypeRegistry::Lookup( @@ -331,23 +380,30 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( const std::string &); // NOLINTNEXTLINE[readability-function-cognitive-complexity] -/*static*/ void PyTreeTypeRegistry::Clear(const std::optional &interpreter_id) { +/*static*/ void PyTreeTypeRegistry::Clear() { const scoped_write_lock lock{sm_mutex}; + const ssize_t interpreter_id = GetPyInterpreterID(); + PyTreeTypeRegistry * const registry1 = PyTreeTypeRegistry::Singleton(); PyTreeTypeRegistry * const registry2 = PyTreeTypeRegistry::Singleton(); - EXPECT_LE(sm_builtins_types.size(), registry1->m_registrations.size()); - EXPECT_EQ(registry1->m_registrations.size(), registry2->m_registrations.size() + 1); - EXPECT_EQ(registry1->m_named_registrations.size(), registry2->m_named_registrations.size()); + const auto [registrations1, named_registrations1, builtins_types] = + registry1->GetRegistrationsForInterpreterLocked(); + const auto [registrations2, named_registrations2, _] = + registry2->GetRegistrationsForInterpreterLocked(); + + EXPECT_LE(builtins_types->size(), registrations1->size()); + EXPECT_EQ(registrations1->size(), registrations2->size() + 1); + EXPECT_EQ(named_registrations1->size(), named_registrations2->size()); #if defined(Py_DEBUG) - for (const auto &cls : sm_builtins_types) { - EXPECT_NE(registry1->m_registrations.find(cls), registry1->m_registrations.end()); + for (const auto &cls : builtins_types) { + EXPECT_NE(registrations1->find(cls), registrations1->end()); } - for (const auto &[cls2, registration2] : registry2->m_registrations) { - const auto it1 = registry1->m_registrations.find(cls2); - EXPECT_NE(it1, registry1->m_registrations.end()); + for (const auto &[cls2, registration2] : *registrations2) { + const auto it1 = registrations1->find(cls2); + EXPECT_NE(it1, registrations1->end()); const auto ®istration1 = it1->second; EXPECT_TRUE(registration1->type.is(registration2->type)); @@ -355,9 +411,9 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( EXPECT_TRUE(registration1->unflatten_func.is(registration2->unflatten_func)); EXPECT_TRUE(registration1->path_entry_type.is(registration2->path_entry_type)); } - for (const auto &[cls2, registration2] : registry2->m_named_registrations) { - const auto it1 = registry1->m_named_registrations.find(cls2); - EXPECT_NE(it1, registry1->m_named_registrations.end()); + for (const auto &[cls2, registration2] : *named_registrations2) { + const auto it1 = named_registrations1->find(cls2); + EXPECT_NE(it1, named_registrations1->end()); const auto ®istration1 = it1->second; EXPECT_TRUE(registration1->type.is(registration2->type)); @@ -367,78 +423,34 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( } #endif - EXPECT_EQ(py::ssize_t_cast(sm_interpreter_scoped_registered_types.size()), - sm_num_interpreters_alive); - - if (!interpreter_id) [[unlikely]] { - EXPECT_EQ(sm_num_interpreters_alive, - 1, - "interpreter_id must be provided when multiple interpreters are alive."); - } - - if (sm_num_interpreters_alive > 1) [[unlikely]] { - EXPECT_TRUE(interpreter_id, - "interpreter_id must be provided when multiple interpreters are alive."); - const auto it = sm_interpreter_scoped_registered_types.find(*interpreter_id); - if (it != sm_interpreter_scoped_registered_types.end()) [[likely]] { - for (const auto &cls : it->second) { - const auto reg_it1 = registry1->m_registrations.find(cls); - if (reg_it1 != registry1->m_registrations.end()) [[likely]] { - const auto registration1 = 
reg_it1->second; - registration1->type.dec_ref(); - registration1->flatten_func.dec_ref(); - registration1->unflatten_func.dec_ref(); - registration1->path_entry_type.dec_ref(); - registry1->m_registrations.erase(reg_it1); - } - std::vector> named_types_to_erase{}; - for (const auto &[named_type, _] : registry1->m_named_registrations) { - if (named_type.second.is(cls)) [[likely]] { - named_types_to_erase.emplace_back(named_type); - } - } - for (const auto &named_type : named_types_to_erase) { - const auto named_it1 = registry1->m_named_registrations.find(named_type); - EXPECT_NE(named_it1, registry1->m_named_registrations.end()); - const auto registration1 = named_it1->second; - registration1->type.dec_ref(); - registration1->flatten_func.dec_ref(); - registration1->unflatten_func.dec_ref(); - registration1->path_entry_type.dec_ref(); - registry1->m_named_registrations.erase(named_it1); - } - registry2->m_registrations.erase(cls); - for (const auto &named_type : named_types_to_erase) { - registry2->m_named_registrations.erase(named_type); - } - } - sm_interpreter_scoped_registered_types.erase(it); - } else [[unlikely]] { - INTERNAL_ERROR("Interpreter ID " + std::to_string(*interpreter_id) + - " not found in `sm_interpreter_scoped_registered_types`."); - } - --sm_num_interpreters_alive; - return; - } + EXPECT_EQ(py::ssize_t_cast(sm_builtins_types.size()), sm_num_interpreters_alive); - for (const auto &[_, registration] : registry1->m_registrations) { + for (const auto &[_, registration] : *registrations1) { registration->type.dec_ref(); registration->flatten_func.dec_ref(); registration->unflatten_func.dec_ref(); registration->path_entry_type.dec_ref(); } - for (const auto &[_, registration] : registry1->m_named_registrations) { + for (const auto &[_, registration] : *named_registrations1) { registration->type.dec_ref(); registration->flatten_func.dec_ref(); registration->unflatten_func.dec_ref(); registration->path_entry_type.dec_ref(); } - sm_builtins_types.clear(); - registry1->m_registrations.clear(); - registry1->m_named_registrations.clear(); - registry2->m_registrations.clear(); - registry2->m_named_registrations.clear(); + builtins_types->clear(); + registrations1->clear(); + named_registrations1->clear(); + registrations2->clear(); + named_registrations2->clear(); + + sm_builtins_types.erase(interpreter_id); + registry1->m_registrations.erase(interpreter_id); + registry1->m_named_registrations.erase(interpreter_id); + registry2->m_registrations.erase(interpreter_id); + registry2->m_named_registrations.erase(interpreter_id); + + --sm_num_interpreters_alive; } } // namespace optree From 1fd07c8cf0bef11a32116860ddede9880d7f368d Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 1 Dec 2025 23:32:33 +0800 Subject: [PATCH 12/59] fix: fix collections type for subinterpreters --- include/optree/pymacros.h | 5 +- include/optree/pytypes.h | 10 ++++ src/registry.cpp | 18 ++++--- tests/concurrent/test_subinterpreters.py | 64 ++++++++++++++++++++++++ 4 files changed, 87 insertions(+), 10 deletions(-) diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h index 13841bf2..106fd4ee 100644 --- a/include/optree/pymacros.h +++ b/include/optree/pymacros.h @@ -17,6 +17,7 @@ limitations under the License. 
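
For illustration of why the cached `collections` accessors need special-casing in this change, a minimal sketch, separate from the diff: each subinterpreter imports its own copy of the `collections` module, so the type objects differ between interpreters and a process-wide cache would hand one interpreter another interpreter's types. It assumes Python 3.14's `concurrent.interpreters` module, which the new tests in this series also rely on.

    from concurrent import interpreters

    def deque_type_id():
        import collections
        return id(collections.deque)

    main_id = deque_type_id()
    interp = interpreters.create()
    try:
        # The subinterpreter resolves collections.deque to its own type object,
        # so its identity generally differs from the main interpreter's.
        sub_id = interp.call(deque_type_id)
        print(main_id == sub_id)  # expected: False
    finally:
        interp.close()
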
#pragma once +#include // std::int64_t #include // std::runtime_error #include @@ -102,7 +103,7 @@ Py_Declare_ID(n_unnamed_fields); // structseq.n_unnamed_fields (defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \ NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT)) -inline py::ssize_t GetPyInterpreterID() { +inline std::int64_t GetPyInterpreterID() { PyInterpreterState *interp = PyInterpreterState_Get(); if (interp == nullptr) [[unlikely]] { throw std::runtime_error("Failed to get the current Python interpreter state."); @@ -112,7 +113,7 @@ inline py::ssize_t GetPyInterpreterID() { #else -inline constexpr py::ssize_t GetPyInterpreterID() noexcept { +inline constexpr std::int64_t GetPyInterpreterID() noexcept { // Fallback for Python versions < 3.14 or when subinterpreter support is not available. return 0; } diff --git a/include/optree/pytypes.h b/include/optree/pytypes.h index b031929b..daf1a9f2 100644 --- a/include/optree/pytypes.h +++ b/include/optree/pytypes.h @@ -75,6 +75,15 @@ inline const py::object &ImportOrderedDict() { }) .get_stored(); } +#if !defined(PYPY_VERSION) && \ + (defined(PY_VERSION_HEX) && PY_VERSION_HEX >= 0x030E0000 /* Python 3.14 */) && \ + (defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \ + NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT)) +inline py::object ImportDefaultDict() { + return py::getattr(py::module_::import("collections"), "defaultdict"); +} +inline py::object ImportDeque() { return py::getattr(py::module_::import("collections"), "deque"); } +#else inline const py::object &ImportDefaultDict() { PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; return storage @@ -90,6 +99,7 @@ inline const py::object &ImportDeque() { []() -> py::object { return py::getattr(py::module_::import("collections"), "deque"); }) .get_stored(); } +#endif inline Py_ALWAYS_INLINE py::ssize_t TupleGetSize(const py::handle &tuple) { return PyTuple_GET_SIZE(tuple.ptr()); diff --git a/src/registry.cpp b/src/registry.cpp index 6c73fdc6..6be4e436 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -50,7 +50,7 @@ std::tuple PyTreeTypeRegistry::GetRegistrationsForInterpreterLocked() const { - const ssize_t interpreter_id = GetPyInterpreterID(); + const auto interpreter_id = GetPyInterpreterID(); EXPECT_NE( m_registrations.find(interpreter_id), @@ -79,7 +79,7 @@ PyTreeTypeRegistry::GetRegistrationsForInterpreterLocked() const { void PyTreeTypeRegistry::Init() { const scoped_write_lock lock{sm_mutex}; - const ssize_t interpreter_id = GetPyInterpreterID(); + const auto interpreter_id = GetPyInterpreterID(); EXPECT_EQ(m_registrations.find(interpreter_id), m_registrations.end(), @@ -133,9 +133,11 @@ void PyTreeTypeRegistry::Init() { ssize_t PyTreeTypeRegistry::Size(const std::optional ®istry_namespace) const { const scoped_read_lock lock{sm_mutex}; - const auto [registrations, named_registrations, _] = + const auto [registrations, named_registrations, builtins_types] = GetRegistrationsForInterpreterLocked(); + (void)builtins_types; // silence unused variable warning + ssize_t count = py::ssize_t_cast(registrations->size()); for (const auto &[named_type, _] : *named_registrations) { if (!registry_namespace || named_type.first == *registry_namespace) [[likely]] { @@ -151,7 +153,7 @@ template const py::function &unflatten_func, const py::object &path_entry_type, const std::string ®istry_namespace) { - PyTreeTypeRegistry * const registry = Singleton(); + const PyTreeTypeRegistry * const registry = Singleton(); const auto [registrations, named_registrations, 
builtins_types] = registry->GetRegistrationsForInterpreterLocked(); @@ -249,7 +251,7 @@ template /*static*/ PyTreeTypeRegistry::RegistrationPtr PyTreeTypeRegistry::UnregisterImpl( const py::object &cls, const std::string ®istry_namespace) { - PyTreeTypeRegistry * const registry = Singleton(); + const PyTreeTypeRegistry * const registry = Singleton(); const auto [registrations, named_registrations, builtins_types] = registry->GetRegistrationsForInterpreterLocked(); @@ -323,7 +325,7 @@ template const std::string ®istry_namespace) { const scoped_read_lock lock{sm_mutex}; - PyTreeTypeRegistry * const registry = PyTreeTypeRegistry::Singleton(); + const PyTreeTypeRegistry * const registry = PyTreeTypeRegistry::Singleton(); const auto [registrations, named_registrations, _] = registry->GetRegistrationsForInterpreterLocked(); @@ -383,7 +385,7 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( /*static*/ void PyTreeTypeRegistry::Clear() { const scoped_write_lock lock{sm_mutex}; - const ssize_t interpreter_id = GetPyInterpreterID(); + const auto interpreter_id = GetPyInterpreterID(); PyTreeTypeRegistry * const registry1 = PyTreeTypeRegistry::Singleton(); PyTreeTypeRegistry * const registry2 = PyTreeTypeRegistry::Singleton(); @@ -398,7 +400,7 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( EXPECT_EQ(named_registrations1->size(), named_registrations2->size()); #if defined(Py_DEBUG) - for (const auto &cls : builtins_types) { + for (const auto &cls : *builtins_types) { EXPECT_NE(registrations1->find(cls), registrations1->end()); } for (const auto &[cls2, registration2] : *registrations2) { diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index 0133f14c..760670ac 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -80,10 +80,65 @@ def concurrent_run(func, /, *args, **kwargs): run(object) # warm-up +def check_module_importable(): + import collections + import time + + import optree._C + + if optree._C.get_registry_size() != 8: + raise RuntimeError('registry size mismatch') + + tree = { + 'b': [2, (3, 4)], + 'a': 1, + 'c': collections.OrderedDict( + f=None, + d=5, + e=time.struct_time([6] + [None] * (time.struct_time.n_sequence_fields - 1)), + ), + 'g': collections.defaultdict(list, h=collections.deque([7, 8, 9], maxlen=10)), + } + + flat, spec = optree._C.flatten(tree) + reconstructed = spec.unflatten(flat) + if reconstructed != tree: + raise RuntimeError('unflatten/flatten mismatch') + if spec.num_leaves != 9: + raise RuntimeError(f'num_leaves mismatch: ({flat}, {spec})') + if flat != [1, 2, 3, 4, 5, 6, 7, 8, 9]: + raise RuntimeError(f'flattened leaves mismatch: ({flat}, {spec})') + + return ( + id(type(None)), + id(tuple), + id(list), + id(dict), + id(collections.OrderedDict), + ) + + def test_import(): + import collections + + expected = ( + id(type(None)), + id(tuple), + id(list), + id(dict), + id(collections.OrderedDict), + ) + + assert run(check_module_importable) == expected + for _ in range(random.randint(5, 10)): with contextlib.closing(interpreters.create()) as subinterpreter: subinterpreter.exec('import optree') + with contextlib.closing(interpreters.create()) as subinterpreter: + assert subinterpreter.call(check_module_importable) == expected + + for actual in concurrent_run(check_module_importable): + assert actual == expected with contextlib.ExitStack() as stack: subinterpreters = [ @@ -93,3 +148,12 @@ def test_import(): random.shuffle(subinterpreters) for subinterpreter in 
subinterpreters: subinterpreter.exec('import optree') + + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range(random.randint(5, 10)) + ] + random.shuffle(subinterpreters) + for subinterpreter in subinterpreters: + assert subinterpreter.call(check_module_importable) == expected From 209801d48b8af654615352891aec73c3e9067772 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 2 Dec 2025 13:13:53 +0800 Subject: [PATCH 13/59] style: add `[[nodiscard]]` attributes --- include/optree/pymacros.h | 6 +- include/optree/registry.h | 6 +- include/optree/stdutils.h | 2 +- include/optree/synchronization.h | 2 +- include/optree/treespec.h | 110 ++++++++++++++++--------------- src/registry.cpp | 3 +- 6 files changed, 68 insertions(+), 61 deletions(-) diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h index 106fd4ee..e6d10c06 100644 --- a/include/optree/pymacros.h +++ b/include/optree/pymacros.h @@ -62,7 +62,7 @@ inline constexpr Py_ALWAYS_INLINE bool Py_IsConstant(PyObject *x) noexcept { #define Py_Declare_ID(name) \ namespace { \ - inline PyObject *Py_ID_##name() { \ + [[nodiscard]] inline PyObject *Py_ID_##name() { \ PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; \ return storage \ .call_once_and_store_result([]() -> PyObject * { \ @@ -103,7 +103,7 @@ Py_Declare_ID(n_unnamed_fields); // structseq.n_unnamed_fields (defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \ NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT)) -inline std::int64_t GetPyInterpreterID() { +[[nodiscard]] inline std::int64_t GetPyInterpreterID() { PyInterpreterState *interp = PyInterpreterState_Get(); if (interp == nullptr) [[unlikely]] { throw std::runtime_error("Failed to get the current Python interpreter state."); @@ -113,7 +113,7 @@ inline std::int64_t GetPyInterpreterID() { #else -inline constexpr std::int64_t GetPyInterpreterID() noexcept { +[[nodiscard]] inline constexpr std::int64_t GetPyInterpreterID() noexcept { // Fallback for Python versions < 3.14 or when subinterpreter support is not available. return 0; } diff --git a/include/optree/registry.h b/include/optree/registry.h index 291bd03d..a1c6535e 100644 --- a/include/optree/registry.h +++ b/include/optree/registry.h @@ -100,11 +100,11 @@ class PyTreeTypeRegistry { using RegistrationPtr = std::shared_ptr; - // Gets the number of registered types. + // Get the number of registered types. [[nodiscard]] ssize_t Size( const std::optional ®istry_namespace = std::nullopt) const; - // Registers a new custom type. Objects of `cls` will be treated as container node types in + // Register a new custom type. Objects of `cls` will be treated as container node types in // PyTrees. static void Register(const py::object &cls, const py::function &flatten_func, @@ -114,7 +114,7 @@ class PyTreeTypeRegistry { static void Unregister(const py::object &cls, const std::string ®istry_namespace = ""); - // Finds the custom type registration for `type`. Returns nullptr if none exists. + // Find the custom type registration for `type`. Returns nullptr if none exists. template [[nodiscard]] static RegistrationPtr Lookup(const py::object &cls, const std::string ®istry_namespace); diff --git a/include/optree/stdutils.h b/include/optree/stdutils.h index 87dc7568..bb50110c 100644 --- a/include/optree/stdutils.h +++ b/include/optree/stdutils.h @@ -23,7 +23,7 @@ limitations under the License. 
#include "optree/pymacros.h" // Py_ALWAYS_INLINE template -inline Py_ALWAYS_INLINE std::vector reserved_vector(std::size_t size) { +[[nodiscard]] inline Py_ALWAYS_INLINE std::vector reserved_vector(std::size_t size) { std::vector v{}; v.reserve(size); return v; diff --git a/include/optree/synchronization.h b/include/optree/synchronization.h index b06c57ce..461e82ee 100644 --- a/include/optree/synchronization.h +++ b/include/optree/synchronization.h @@ -178,6 +178,6 @@ class scoped_critical_section2 { #endif template -inline Py_ALWAYS_INLINE T thread_safe_cast(const py::handle &handle) { +[[nodiscard]] inline Py_ALWAYS_INLINE T thread_safe_cast(const py::handle &handle) { return EVALUATE_WITH_LOCK_HELD(py::cast(handle), handle); } diff --git a/include/optree/treespec.h b/include/optree/treespec.h index 8572c7c8..fff08b8d 100644 --- a/include/optree/treespec.h +++ b/include/optree/treespec.h @@ -49,15 +49,15 @@ constexpr ssize_t MAX_RECURSION_DEPTH = 1000; #endif // Test whether the given object is a leaf node. -bool IsLeaf(const py::object &object, - const std::optional &leaf_predicate, - const bool &none_is_leaf = false, - const std::string ®istry_namespace = ""); +[[nodiscard]] bool IsLeaf(const py::object &object, + const std::optional &leaf_predicate, + const bool &none_is_leaf = false, + const std::string ®istry_namespace = ""); template -bool IsLeafImpl(const py::handle &handle, - const std::optional &leaf_predicate, - const std::string ®istry_namespace); +[[nodiscard]] bool IsLeafImpl(const py::handle &handle, + const std::optional &leaf_predicate, + const std::string ®istry_namespace); py::module_ GetCxxModule(const std::optional &module = std::nullopt); @@ -88,7 +88,7 @@ class PyTreeSpec { // Flatten a PyTree into a list of leaves and a PyTreeSpec. // Return references to the flattened objects, which might be temporary objects in the case of // custom PyType handlers. - static std::pair, std::unique_ptr> Flatten( + [[nodiscard]] static std::pair, std::unique_ptr> Flatten( const py::object &tree, const std::optional &leaf_predicate = std::nullopt, const bool &none_is_leaf = false, @@ -97,11 +97,12 @@ class PyTreeSpec { // Flatten a PyTree into a list of leaves with a list of paths and a PyTreeSpec. // Return references to the flattened objects, which might be temporary objects in the case of // custom PyType handlers. - static std::tuple, std::vector, std::unique_ptr> - FlattenWithPath(const py::object &tree, - const std::optional &leaf_predicate = std::nullopt, - const bool &none_is_leaf = false, - const std::string ®istry_namespace = ""); + [[nodiscard]] static std:: + tuple, std::vector, std::unique_ptr> + FlattenWithPath(const py::object &tree, + const std::optional &leaf_predicate = std::nullopt, + const bool &none_is_leaf = false, + const std::string ®istry_namespace = ""); // Return an unflattened PyTree given an iterable of leaves and a PyTreeSpec. [[nodiscard]] py::object Unflatten(const py::iterable &leaves) const; @@ -240,24 +241,26 @@ class PyTreeSpec { // Transform the object returned by `ToPickleable()` back to PyTreeSpec. // Used to implement `PyTreeSpec.__setstate__`. - static std::unique_ptr FromPickleable(const py::object &pickleable); + [[nodiscard]] static std::unique_ptr FromPickleable(const py::object &pickleable); // Make a PyTreeSpec representing a leaf node. 
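
For illustration, the Python-level counterpart of the `Flatten`/`Unflatten` pair declared above is a plain flatten/unflatten round trip; a minimal sketch with the public API:

    import optree

    tree = {'x': [1, (2, 3)], 'y': None}
    leaves, treespec = optree.tree_flatten(tree)
    # With the default none_is_leaf=False, None is an internal node with no leaves.
    assert leaves == [1, 2, 3]
    assert optree.tree_unflatten(treespec, leaves) == tree
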
- static std::unique_ptr MakeLeaf(const bool &none_is_leaf = false, - const std::string ®istry_namespace = ""); + [[nodiscard]] static std::unique_ptr MakeLeaf( + const bool &none_is_leaf = false, + const std::string ®istry_namespace = ""); // Make a PyTreeSpec representing a `None` node. - static std::unique_ptr MakeNone(const bool &none_is_leaf = false, - const std::string ®istry_namespace = ""); + [[nodiscard]] static std::unique_ptr MakeNone( + const bool &none_is_leaf = false, + const std::string ®istry_namespace = ""); // Make a PyTreeSpec out of a collection of PyTreeSpecs. - static std::unique_ptr MakeFromCollection( + [[nodiscard]] static std::unique_ptr MakeFromCollection( const py::object &object, const bool &none_is_leaf = false, const std::string ®istry_namespace = ""); // Check if should preserve the insertion order of the dictionary keys during flattening. - static inline Py_ALWAYS_INLINE bool IsDictInsertionOrdered( + [[nodiscard]] static inline Py_ALWAYS_INLINE bool IsDictInsertionOrdered( const std::string ®istry_namespace, const bool &inherit_global_namespace = true) { const scoped_read_lock lock{sm_is_dict_insertion_ordered_mutex}; @@ -331,55 +334,57 @@ class PyTreeSpec { std::string m_namespace{}; // Helper that returns the string representation of a node kind. - static std::string NodeKindToString(const Node &node); + [[nodiscard]] static std::string NodeKindToString(const Node &node); // Helper that manufactures an instance of a node given its children. - static py::object MakeNode(const Node &node, - const py::object children[], // NOLINT[hicpp-avoid-c-arrays] - const size_t &num_children); + [[nodiscard]] static py::object MakeNode( + const Node &node, + const py::object children[], // NOLINT[hicpp-avoid-c-arrays] + const size_t &num_children); // Helper that identifies the path entry class for a node. - static py::object GetPathEntryType(const Node &node); + [[nodiscard]] static py::object GetPathEntryType(const Node &node); // Recursive helper used to implement Flatten(). - bool FlattenInto(const py::handle &handle, - std::vector &leaves, // NOLINT[runtime/references] - const std::optional &leaf_predicate, - const bool &none_is_leaf, - const std::string ®istry_namespace); + [[nodiscard]] bool FlattenInto(const py::handle &handle, + std::vector &leaves, // NOLINT[runtime/references] + const std::optional &leaf_predicate, + const bool &none_is_leaf, + const std::string ®istry_namespace); template - bool FlattenIntoImpl(const py::handle &handle, - Vector &leaves, // NOLINT[runtime/references] - const ssize_t &depth, - const std::optional &leaf_predicate, - const std::string ®istry_namespace); + [[nodiscard]] bool FlattenIntoImpl(const py::handle &handle, + Vector &leaves, // NOLINT[runtime/references] + const ssize_t &depth, + const std::optional &leaf_predicate, + const std::string ®istry_namespace); // Recursive helper used to implement FlattenWithPath(). 
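
For illustration of the question `IsDictInsertionOrdered` answers above, a minimal sketch of the default behavior with the public API: plain dict keys are sorted during flattening, while `collections.OrderedDict` keeps its insertion order.

    import collections
    import optree

    plain = {'b': 2, 'a': 1}
    ordered = collections.OrderedDict([('b', 2), ('a', 1)])

    assert optree.tree_flatten(plain)[0] == [1, 2]    # sorted by key: 'a', 'b'
    assert optree.tree_flatten(ordered)[0] == [2, 1]  # insertion order preserved
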
- bool FlattenIntoWithPath(const py::handle &handle, - std::vector &leaves, // NOLINT[runtime/references] - std::vector &paths, // NOLINT[runtime/references] - const std::optional &leaf_predicate, - const bool &none_is_leaf, - const std::string ®istry_namespace); + [[nodiscard]] bool FlattenIntoWithPath( + const py::handle &handle, + std::vector &leaves, // NOLINT[runtime/references] + std::vector &paths, // NOLINT[runtime/references] + const std::optional &leaf_predicate, + const bool &none_is_leaf, + const std::string ®istry_namespace); template - bool FlattenIntoWithPathImpl(const py::handle &handle, - LeafVector &leaves, // NOLINT[runtime/references] - PathVector &paths, // NOLINT[runtime/references] - Stack &stack, // NOLINT[runtime/references] - const ssize_t &depth, - const std::optional &leaf_predicate, - const std::string ®istry_namespace); + [[nodiscard]] bool FlattenIntoWithPathImpl(const py::handle &handle, + LeafVector &leaves, // NOLINT[runtime/references] + PathVector &paths, // NOLINT[runtime/references] + Stack &stack, // NOLINT[runtime/references] + const ssize_t &depth, + const std::optional &leaf_predicate, + const std::string ®istry_namespace); template - py::object UnflattenImpl(const Span &leaves) const; + [[nodiscard]] py::object UnflattenImpl(const Span &leaves) const; - static std::tuple BroadcastToCommonSuffixImpl( + [[nodiscard]] static std::tuple BroadcastToCommonSuffixImpl( std::vector &nodes, // NOLINT[runtime/references] const std::vector &traversal, const ssize_t &pos, @@ -409,8 +414,9 @@ class PyTreeSpec { [[nodiscard]] ssize_t HashValueImpl() const; template - static std::unique_ptr MakeFromCollectionImpl(const py::handle &handle, - std::string registry_namespace); + [[nodiscard]] static std::unique_ptr MakeFromCollectionImpl( + const py::handle &handle, + std::string registry_namespace); // Used in tp_traverse for GC support. 
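
For illustration, the path-aware flattening that `FlattenIntoWithPath` above implements is what `optree.tree_flatten_with_path` exposes; a minimal sketch, assuming the usual (paths, leaves, treespec) return order:

    import optree

    paths, leaves, treespec = optree.tree_flatten_with_path({'a': [1, 2], 'b': 3})
    assert paths == [('a', 0), ('a', 1), ('b',)]
    assert leaves == [1, 2, 3]
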
static int PyTpTraverse(PyObject *self_base, visitproc visit, void *arg); diff --git a/src/registry.cpp b/src/registry.cpp index 6be4e436..9838870f 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -392,12 +392,13 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( const auto [registrations1, named_registrations1, builtins_types] = registry1->GetRegistrationsForInterpreterLocked(); - const auto [registrations2, named_registrations2, _] = + const auto [registrations2, named_registrations2, builtins_types_] = registry2->GetRegistrationsForInterpreterLocked(); EXPECT_LE(builtins_types->size(), registrations1->size()); EXPECT_EQ(registrations1->size(), registrations2->size() + 1); EXPECT_EQ(named_registrations1->size(), named_registrations2->size()); + EXPECT_EQ(builtins_types, builtins_types_); #if defined(Py_DEBUG) for (const auto &cls : *builtins_types) { From 02648ab9909ef89fe53202997b39915f13bd9e54 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 6 Dec 2025 14:06:49 +0800 Subject: [PATCH 14/59] chore(pre-commit): update pre-commit hooks --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fd33e89f..c7e154cb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: args: [--ignore-case] files: ^docs/source/spelling_wordlist\.txt$ - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v21.1.6 + rev: v21.1.7 hooks: - id: clang-format - repo: https://github.com/cpplint/cpplint @@ -38,7 +38,7 @@ repos: hooks: - id: cpplint - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.7 + rev: v0.14.8 hooks: - id: ruff-check args: [--fix, --exit-non-zero-on-fix] From a32c672be53c2eeaa51b377e09815ecad5c2118e Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 6 Dec 2025 14:17:16 +0800 Subject: [PATCH 15/59] chore: sync changes --- .github/workflows/build.yml | 4 +- .github/workflows/tests-with-pydebug.yml | 42 +++++++++----- .github/workflows/tests.yml | 71 +++++++++++++++++------- docs/requirements.txt | 4 +- pyproject.toml | 13 +++-- tests/helpers.py | 18 ++++-- tests/requirements.txt | 17 +++--- tests/test_typing.py | 4 +- 8 files changed, 115 insertions(+), 58 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index bb85665d..89605871 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -42,9 +42,11 @@ concurrency: cancel-in-progress: ${{ github.event_name == 'pull_request' }} env: + OPTREE_CXX_WERROR: "OFF" + _GLIBCXX_USE_CXX11_ABI: "1" + PYTHONUNBUFFERED: "1" PYTHON_TAG: "py3" # to be updated PYTHON_VERSION: "3" # to be updated - _GLIBCXX_USE_CXX11_ABI: "1" COLUMNS: "100" FORCE_COLOR: "1" CLICOLOR_FORCE: "1" diff --git a/.github/workflows/tests-with-pydebug.yml b/.github/workflows/tests-with-pydebug.yml index ee48a90f..722fca3a 100644 --- a/.github/workflows/tests-with-pydebug.yml +++ b/.github/workflows/tests-with-pydebug.yml @@ -348,10 +348,15 @@ jobs: "--junit-xml=junit-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml" ) make test PYTESTOPTS="${PYTESTOPTS[*]}" - if [[ -n "$(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp")" ]]; then + + CORE_DUMP_FILES="$( + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + )" + if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "::error::Coredump files found, indicating a crash during tests." >&2 echo "Coredump files:" >&2 - ls -alh $(find . 
-iname "core.*.[1-9]*" -o -iname "core_*.dmp") >&2 + ls -alh ${CORE_DUMP_FILES} >&2 exit 1 fi @@ -369,9 +374,13 @@ jobs: if: ${{ !cancelled() }} shell: bash run: | - if [[ -n "$(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp")" ]]; then + CORE_DUMP_FILES="$( + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + )" + if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "Found core dumps:" - ls -alh $(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp") + ls -alh ${CORE_DUMP_FILES} BACKTRACE_COMMAND="" if [[ "${{ runner.os }}" == 'Linux' ]]; then echo "::group::Install GDB" @@ -392,18 +401,23 @@ jobs: fi if [[ -n "${BACKTRACE_COMMAND}" ]]; then echo "Collecting backtraces:" - find . -iname "core.*.[1-9]*" -exec bash -xc " - echo '::group::backtrace from: {}'; - ${BACKTRACE_COMMAND}; - echo '::endgroup::'; - " ';' - find . -iname "core_*.dmp" -exec bash -xc " - echo '::group::backtrace from: {}'; - ${BACKTRACE_COMMAND}; - echo '::endgroup::'; - " ';' + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" \ + -exec bash -xc " + echo '::group::backtrace from: {}'; + ${BACKTRACE_COMMAND}; + echo '::endgroup::'; + " ';' + find . -type d -path "./venv" -prune \ + -o -iname "core_*.dmp" \ + -exec bash -xc " + echo '::group::backtrace from: {}'; + ${BACKTRACE_COMMAND}; + echo '::endgroup::'; + " ';' fi echo "::warning::Coredump files found, see backtraces above for details." >&2 + exit 1 fi - name: Upload coverage to Codecov diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 87cee51a..7cd1dda6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -209,14 +209,19 @@ jobs: popd rm -rf venv ) + if [[ "$?" -ne 0 ]]; then echo "::error::Failed to install with C++17." >&2 exit 1 fi - if [[ -n "$(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp")" ]]; then + CORE_DUMP_FILES="$( + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + )" + if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "::error::Coredump files found, indicating a crash during tests." >&2 echo "Coredump files:" >&2 - ls -alh $(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp") >&2 + ls -alh ${CORE_DUMP_FILES} >&2 exit 1 fi @@ -251,6 +256,7 @@ jobs: "${{ env.PYTHON }}" ) fi + ( set -ex ${{ env.PYTHON }} -m venv venv @@ -262,14 +268,19 @@ jobs: popd rm -rf venv ) + if [[ "$?" -ne 0 ]]; then echo "::error::Failed to install with CMake from PyPI." >&2 exit 1 fi - if [[ -n "$(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp")" ]]; then + CORE_DUMP_FILES="$( + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + )" + if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "::error::Coredump files found, indicating a crash during tests." >&2 echo "Coredump files:" >&2 - ls -alh $(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp") >&2 + ls -alh ${CORE_DUMP_FILES} >&2 exit 1 fi @@ -290,6 +301,7 @@ jobs: "${{ env.PYTHON }}" ) fi + ( set -ex ${{ env.PYTHON }} -m venv venv @@ -300,14 +312,19 @@ jobs: popd rm -rf venv ) + if [[ "$?" -ne 0 ]]; then echo "::error::Failed to install with CMake from PyPI (no system CMake)." >&2 exit 1 fi - if [[ -n "$(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp")" ]]; then + CORE_DUMP_FILES="$( + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + )" + if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "::error::Coredump files found, indicating a crash during tests." 
>&2 echo "Coredump files:" >&2 - ls -alh $(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp") >&2 + ls -alh ${CORE_DUMP_FILES} >&2 exit 1 fi @@ -323,10 +340,15 @@ jobs: "--junit-xml=junit-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml" ) make test PYTESTOPTS="${PYTESTOPTS[*]}" - if [[ -n "$(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp")" ]]; then + + CORE_DUMP_FILES="$( + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + )" + if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "::error::Coredump files found, indicating a crash during tests." >&2 echo "Coredump files:" >&2 - ls -alh $(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp") >&2 + ls -alh ${CORE_DUMP_FILES} >&2 exit 1 fi @@ -344,9 +366,13 @@ jobs: if: ${{ !cancelled() }} shell: bash run: | - if [[ -n "$(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp")" ]]; then + CORE_DUMP_FILES="$( + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + )" + if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "Found core dumps:" - ls -alh $(find . -iname "core.*.[1-9]*" -o -iname "core_*.dmp") + ls -alh ${CORE_DUMP_FILES} BACKTRACE_COMMAND="" if [[ "${{ runner.os }}" == 'Linux' ]]; then echo "::group::Install GDB" @@ -367,18 +393,23 @@ jobs: fi if [[ -n "${BACKTRACE_COMMAND}" ]]; then echo "Collecting backtraces:" - find . -iname "core.*.[1-9]*" -exec bash -xc " - echo '::group::backtrace from: {}'; - ${BACKTRACE_COMMAND}; - echo '::endgroup::'; - " ';' - find . -iname "core_*.dmp" -exec bash -xc " - echo '::group::backtrace from: {}'; - ${BACKTRACE_COMMAND}; - echo '::endgroup::'; - " ';' + find . -type d -path "./venv" -prune \ + -o -iname "core.*.[1-9]*" \ + -exec bash -xc " + echo '::group::backtrace from: {}'; + ${BACKTRACE_COMMAND}; + echo '::endgroup::'; + " ';' + find . -type d -path "./venv" -prune \ + -o -iname "core_*.dmp" \ + -exec bash -xc " + echo '::group::backtrace from: {}'; + ${BACKTRACE_COMMAND}; + echo '::endgroup::'; + " ';' fi echo "::warning::Coredump files found, see backtraces above for details." 
>&2 + exit 1 fi - name: Upload coverage to Codecov diff --git a/docs/requirements.txt b/docs/requirements.txt index 55e1da0a..85a0ae49 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -3,14 +3,14 @@ sphinx sphinx-autoapi sphinx-autobuild +sphinx-autodoc-typehints sphinx-copybutton sphinx-rtd-theme sphinxcontrib-bibtex -sphinx-autodoc-typehints docutils --extra-index-url https://download.pytorch.org/whl/cpu -jax[cpu] >= 0.4.6 +jax[cpu] numpy torch diff --git a/pyproject.toml b/pyproject.toml index e2b91265..61fd419f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,19 +58,20 @@ jax = ["jax"] numpy = ["numpy"] torch = ["torch"] lint = [ - "ruff", - "pylint[spelling]", - "mypy", + "cpplint", "doc8", + "mypy", + "pre-commit", "pyenchant", + "pylint[spelling]", + "ruff", "xdoctest", - "cpplint", - "pre-commit", ] test = [ "pytest", "pytest-cov", "covdefaults", + "xdoctest", "rich", # Pin to minimum for compatibility test "typing-extensions == 4.6.0; python_version < '3.13' and platform_system == 'Linux'", @@ -84,10 +85,10 @@ docs = [ "sphinx", "sphinx-autoapi", "sphinx-autobuild", + "sphinx-autodoc-typehints", "sphinx-copybutton", "sphinx-rtd-theme", "sphinxcontrib-bibtex", - "sphinx-autodoc-typehints", "docutils", "jax[cpu]", "numpy", diff --git a/tests/helpers.py b/tests/helpers.py index e81a103f..2da6a9b8 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -32,25 +32,31 @@ import pytest import optree -import optree._C +from optree._C import ( + PYBIND11_HAS_NATIVE_ENUM, + PYBIND11_HAS_SUBINTERPRETER_SUPPORT, + Py_DEBUG, + Py_GIL_DISABLED, +) from optree.registry import __GLOBAL_NAMESPACE as GLOBAL_NAMESPACE TEST_ROOT = Path(__file__).absolute().parent +_ = PYBIND11_HAS_NATIVE_ENUM +_ = PYBIND11_HAS_SUBINTERPRETER_SUPPORT + if sysconfig.get_config_var('Py_DEBUG') is None: - Py_DEBUG = hasattr(sys, 'gettotalrefcount') + assert Py_DEBUG == hasattr(sys, 'gettotalrefcount') else: - Py_DEBUG = bool(int(sysconfig.get_config_var('Py_DEBUG') or '0')) -assert Py_DEBUG == optree._C.Py_DEBUG + assert Py_DEBUG == bool(int(sysconfig.get_config_var('Py_DEBUG') or '0')) skipif_pydebug = pytest.mark.skipif( Py_DEBUG, reason='Py_DEBUG is enabled which causes too much overhead', ) -Py_GIL_DISABLED = bool(int(sysconfig.get_config_var('Py_GIL_DISABLED') or '0')) -assert Py_GIL_DISABLED == optree._C.Py_GIL_DISABLED +assert Py_GIL_DISABLED == bool(int(sysconfig.get_config_var('Py_GIL_DISABLED') or '0')) skipif_freethreading = pytest.mark.skipif( Py_GIL_DISABLED, reason='Py_GIL_DISABLED is set', diff --git a/tests/requirements.txt b/tests/requirements.txt index 4ea69f96..58bc3c84 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -3,16 +3,19 @@ pytest pytest-cov covdefaults -rich -ruff -pylint[spelling] -mypy -doc8 -pyenchant xdoctest +rich + cpplint +doc8 +mypy pre-commit +pyenchant +pylint[spelling] +ruff + +--extra-index-url https://download.pytorch.org/whl/cpu -jax[cpu] >= 0.4.6 +jax[cpu] numpy torch diff --git a/tests/test_typing.py b/tests/test_typing.py index 3a4400ad..4e428735 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -26,8 +26,8 @@ import pytest import optree -import optree._C from helpers import ( + PYBIND11_HAS_NATIVE_ENUM, CustomNamedTupleSubclass, CustomTuple, Py_GIL_DISABLED, @@ -72,7 +72,7 @@ class FakeStructSequence(tuple): def test_pytreekind_enum(): - if optree._C.PYBIND11_HAS_NATIVE_ENUM: + if PYBIND11_HAS_NATIVE_ENUM: all_kinds = list(optree.PyTreeKind) assert len(all_kinds) == optree.PyTreeKind.NUM_KINDS assert 
issubclass(optree.PyTreeKind, enum.IntEnum) From c026062462a97b831f73f0b645570bcdaf56ae47 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 6 Dec 2025 14:22:23 +0800 Subject: [PATCH 16/59] chore: sync changes --- include/optree/pymacros.h | 5 +- include/optree/registry.h | 19 +++--- include/optree/stdutils.h | 2 +- include/optree/synchronization.h | 2 +- include/optree/treespec.h | 110 ++++++++++++++++--------------- src/optree.cpp | 18 ++--- src/treespec/gc.cpp | 4 +- src/treespec/serialization.cpp | 2 +- 8 files changed, 84 insertions(+), 78 deletions(-) diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h index 7f2ac9ab..490a6f2f 100644 --- a/include/optree/pymacros.h +++ b/include/optree/pymacros.h @@ -59,7 +59,7 @@ inline constexpr Py_ALWAYS_INLINE bool Py_IsConstant(PyObject *x) noexcept { #define Py_Declare_ID(name) \ namespace { \ - inline PyObject *Py_ID_##name() { \ + [[nodiscard]] inline PyObject *Py_ID_##name() { \ PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; \ return storage \ .call_once_and_store_result([]() -> PyObject * { \ @@ -91,3 +91,6 @@ Py_Declare_ID(_asdict); // namedtuple._asdict Py_Declare_ID(n_fields); // structseq.n_fields Py_Declare_ID(n_sequence_fields); // structseq.n_sequence_fields Py_Declare_ID(n_unnamed_fields); // structseq.n_unnamed_fields + +// NOLINTNEXTLINE[bugprone-macro-parentheses] +#define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0)) diff --git a/include/optree/registry.h b/include/optree/registry.h index d39c792e..709365b5 100644 --- a/include/optree/registry.h +++ b/include/optree/registry.h @@ -98,7 +98,7 @@ class PyTreeTypeRegistry { using RegistrationPtr = std::shared_ptr; - // Registers a new custom type. Objects of `cls` will be treated as container node types in + // Register a new custom type. Objects of `cls` will be treated as container node types in // PyTrees. static void Register(const py::object &cls, const py::function &flatten_func, @@ -108,21 +108,22 @@ class PyTreeTypeRegistry { static void Unregister(const py::object &cls, const std::string ®istry_namespace = ""); - // Finds the custom type registration for `type`. Returns nullptr if none exists. + // Find the custom type registration for `type`. Returns nullptr if none exists. template - static RegistrationPtr Lookup(const py::object &cls, const std::string ®istry_namespace); + [[nodiscard]] static RegistrationPtr Lookup(const py::object &cls, + const std::string ®istry_namespace); // Compute the node kind of a given Python object. template - static PyTreeKind GetKind(const py::handle &handle, - RegistrationPtr &custom, // NOLINT[runtime/references] - const std::string ®istry_namespace); + [[nodiscard]] static PyTreeKind GetKind(const py::handle &handle, + RegistrationPtr &custom, // NOLINT[runtime/references] + const std::string ®istry_namespace); friend void BuildModule(py::module_ &mod); // NOLINT[runtime/references] private: template - static PyTreeTypeRegistry *Singleton(); + [[nodiscard]] static PyTreeTypeRegistry *Singleton(); template static void RegisterImpl(const py::object &cls, @@ -132,8 +133,8 @@ class PyTreeTypeRegistry { const std::string ®istry_namespace); template - static RegistrationPtr UnregisterImpl(const py::object &cls, - const std::string ®istry_namespace); + [[nodiscard]] static RegistrationPtr UnregisterImpl(const py::object &cls, + const std::string ®istry_namespace); // Clear the registry on cleanup. 
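
For illustration of the node-kind decision `GetKind` makes, a minimal sketch with the public API, separate from the diff: namedtuples are recognized as container nodes, while a type the registry does not know falls through to a leaf.

    import collections
    import optree

    Point = collections.namedtuple('Point', ['x', 'y'])

    class Opaque:
        pass

    assert optree.tree_flatten(Point(1, 2))[0] == [1, 2]  # flattened like a tuple
    obj = Opaque()
    assert optree.tree_flatten(obj)[0] == [obj]           # unregistered: a leaf
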
static void Clear(); diff --git a/include/optree/stdutils.h b/include/optree/stdutils.h index 87dc7568..bb50110c 100644 --- a/include/optree/stdutils.h +++ b/include/optree/stdutils.h @@ -23,7 +23,7 @@ limitations under the License. #include "optree/pymacros.h" // Py_ALWAYS_INLINE template -inline Py_ALWAYS_INLINE std::vector reserved_vector(std::size_t size) { +[[nodiscard]] inline Py_ALWAYS_INLINE std::vector reserved_vector(std::size_t size) { std::vector v{}; v.reserve(size); return v; diff --git a/include/optree/synchronization.h b/include/optree/synchronization.h index b06c57ce..461e82ee 100644 --- a/include/optree/synchronization.h +++ b/include/optree/synchronization.h @@ -178,6 +178,6 @@ class scoped_critical_section2 { #endif template -inline Py_ALWAYS_INLINE T thread_safe_cast(const py::handle &handle) { +[[nodiscard]] inline Py_ALWAYS_INLINE T thread_safe_cast(const py::handle &handle) { return EVALUATE_WITH_LOCK_HELD(py::cast(handle), handle); } diff --git a/include/optree/treespec.h b/include/optree/treespec.h index 8572c7c8..fff08b8d 100644 --- a/include/optree/treespec.h +++ b/include/optree/treespec.h @@ -49,15 +49,15 @@ constexpr ssize_t MAX_RECURSION_DEPTH = 1000; #endif // Test whether the given object is a leaf node. -bool IsLeaf(const py::object &object, - const std::optional &leaf_predicate, - const bool &none_is_leaf = false, - const std::string ®istry_namespace = ""); +[[nodiscard]] bool IsLeaf(const py::object &object, + const std::optional &leaf_predicate, + const bool &none_is_leaf = false, + const std::string ®istry_namespace = ""); template -bool IsLeafImpl(const py::handle &handle, - const std::optional &leaf_predicate, - const std::string ®istry_namespace); +[[nodiscard]] bool IsLeafImpl(const py::handle &handle, + const std::optional &leaf_predicate, + const std::string ®istry_namespace); py::module_ GetCxxModule(const std::optional &module = std::nullopt); @@ -88,7 +88,7 @@ class PyTreeSpec { // Flatten a PyTree into a list of leaves and a PyTreeSpec. // Return references to the flattened objects, which might be temporary objects in the case of // custom PyType handlers. - static std::pair, std::unique_ptr> Flatten( + [[nodiscard]] static std::pair, std::unique_ptr> Flatten( const py::object &tree, const std::optional &leaf_predicate = std::nullopt, const bool &none_is_leaf = false, @@ -97,11 +97,12 @@ class PyTreeSpec { // Flatten a PyTree into a list of leaves with a list of paths and a PyTreeSpec. // Return references to the flattened objects, which might be temporary objects in the case of // custom PyType handlers. - static std::tuple, std::vector, std::unique_ptr> - FlattenWithPath(const py::object &tree, - const std::optional &leaf_predicate = std::nullopt, - const bool &none_is_leaf = false, - const std::string ®istry_namespace = ""); + [[nodiscard]] static std:: + tuple, std::vector, std::unique_ptr> + FlattenWithPath(const py::object &tree, + const std::optional &leaf_predicate = std::nullopt, + const bool &none_is_leaf = false, + const std::string ®istry_namespace = ""); // Return an unflattened PyTree given an iterable of leaves and a PyTreeSpec. [[nodiscard]] py::object Unflatten(const py::iterable &leaves) const; @@ -240,24 +241,26 @@ class PyTreeSpec { // Transform the object returned by `ToPickleable()` back to PyTreeSpec. // Used to implement `PyTreeSpec.__setstate__`. 
- static std::unique_ptr FromPickleable(const py::object &pickleable); + [[nodiscard]] static std::unique_ptr FromPickleable(const py::object &pickleable); // Make a PyTreeSpec representing a leaf node. - static std::unique_ptr MakeLeaf(const bool &none_is_leaf = false, - const std::string ®istry_namespace = ""); + [[nodiscard]] static std::unique_ptr MakeLeaf( + const bool &none_is_leaf = false, + const std::string ®istry_namespace = ""); // Make a PyTreeSpec representing a `None` node. - static std::unique_ptr MakeNone(const bool &none_is_leaf = false, - const std::string ®istry_namespace = ""); + [[nodiscard]] static std::unique_ptr MakeNone( + const bool &none_is_leaf = false, + const std::string ®istry_namespace = ""); // Make a PyTreeSpec out of a collection of PyTreeSpecs. - static std::unique_ptr MakeFromCollection( + [[nodiscard]] static std::unique_ptr MakeFromCollection( const py::object &object, const bool &none_is_leaf = false, const std::string ®istry_namespace = ""); // Check if should preserve the insertion order of the dictionary keys during flattening. - static inline Py_ALWAYS_INLINE bool IsDictInsertionOrdered( + [[nodiscard]] static inline Py_ALWAYS_INLINE bool IsDictInsertionOrdered( const std::string ®istry_namespace, const bool &inherit_global_namespace = true) { const scoped_read_lock lock{sm_is_dict_insertion_ordered_mutex}; @@ -331,55 +334,57 @@ class PyTreeSpec { std::string m_namespace{}; // Helper that returns the string representation of a node kind. - static std::string NodeKindToString(const Node &node); + [[nodiscard]] static std::string NodeKindToString(const Node &node); // Helper that manufactures an instance of a node given its children. - static py::object MakeNode(const Node &node, - const py::object children[], // NOLINT[hicpp-avoid-c-arrays] - const size_t &num_children); + [[nodiscard]] static py::object MakeNode( + const Node &node, + const py::object children[], // NOLINT[hicpp-avoid-c-arrays] + const size_t &num_children); // Helper that identifies the path entry class for a node. - static py::object GetPathEntryType(const Node &node); + [[nodiscard]] static py::object GetPathEntryType(const Node &node); // Recursive helper used to implement Flatten(). - bool FlattenInto(const py::handle &handle, - std::vector &leaves, // NOLINT[runtime/references] - const std::optional &leaf_predicate, - const bool &none_is_leaf, - const std::string ®istry_namespace); + [[nodiscard]] bool FlattenInto(const py::handle &handle, + std::vector &leaves, // NOLINT[runtime/references] + const std::optional &leaf_predicate, + const bool &none_is_leaf, + const std::string ®istry_namespace); template - bool FlattenIntoImpl(const py::handle &handle, - Vector &leaves, // NOLINT[runtime/references] - const ssize_t &depth, - const std::optional &leaf_predicate, - const std::string ®istry_namespace); + [[nodiscard]] bool FlattenIntoImpl(const py::handle &handle, + Vector &leaves, // NOLINT[runtime/references] + const ssize_t &depth, + const std::optional &leaf_predicate, + const std::string ®istry_namespace); // Recursive helper used to implement FlattenWithPath(). 
- bool FlattenIntoWithPath(const py::handle &handle, - std::vector &leaves, // NOLINT[runtime/references] - std::vector &paths, // NOLINT[runtime/references] - const std::optional &leaf_predicate, - const bool &none_is_leaf, - const std::string ®istry_namespace); + [[nodiscard]] bool FlattenIntoWithPath( + const py::handle &handle, + std::vector &leaves, // NOLINT[runtime/references] + std::vector &paths, // NOLINT[runtime/references] + const std::optional &leaf_predicate, + const bool &none_is_leaf, + const std::string ®istry_namespace); template - bool FlattenIntoWithPathImpl(const py::handle &handle, - LeafVector &leaves, // NOLINT[runtime/references] - PathVector &paths, // NOLINT[runtime/references] - Stack &stack, // NOLINT[runtime/references] - const ssize_t &depth, - const std::optional &leaf_predicate, - const std::string ®istry_namespace); + [[nodiscard]] bool FlattenIntoWithPathImpl(const py::handle &handle, + LeafVector &leaves, // NOLINT[runtime/references] + PathVector &paths, // NOLINT[runtime/references] + Stack &stack, // NOLINT[runtime/references] + const ssize_t &depth, + const std::optional &leaf_predicate, + const std::string ®istry_namespace); template - py::object UnflattenImpl(const Span &leaves) const; + [[nodiscard]] py::object UnflattenImpl(const Span &leaves) const; - static std::tuple BroadcastToCommonSuffixImpl( + [[nodiscard]] static std::tuple BroadcastToCommonSuffixImpl( std::vector &nodes, // NOLINT[runtime/references] const std::vector &traversal, const ssize_t &pos, @@ -409,8 +414,9 @@ class PyTreeSpec { [[nodiscard]] ssize_t HashValueImpl() const; template - static std::unique_ptr MakeFromCollectionImpl(const py::handle &handle, - std::string registry_namespace); + [[nodiscard]] static std::unique_ptr MakeFromCollectionImpl( + const py::handle &handle, + std::string registry_namespace); // Used in tp_traverse for GC support. 
static int PyTpTraverse(PyObject *self_base, visitproc visit, void *arg); diff --git a/src/optree.cpp b/src/optree.cpp index dfa27e43..1cfbc9bc 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -111,11 +111,9 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] return f'0x{self:08X}' BUILDTIME_METADATA.update( - **{ - name: HexInt(value) - for name, value in BUILDTIME_METADATA.items() - if name.endswith('_HEX') and isinstance(value, int) - }, + (name, HexInt(value)) + for name, value in BUILDTIME_METADATA.items() + if name.endswith('_HEX') and isinstance(value, int) ) BUILDTIME_METADATA = types.MappingProxyType(BUILDTIME_METADATA) @@ -277,10 +275,8 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] #endif auto * const PyTreeKind_Type = reinterpret_cast(PyTreeKindTypeObject.ptr()); PyTreeKind_Type->tp_name = "optree.PyTreeKind"; - py::setattr(PyTreeKindTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree)); - py::setattr(PyTreeKindTypeObject.ptr(), - "NUM_KINDS", - py::int_(py::ssize_t(PyTreeKind::NumKinds))); + py::setattr(PyTreeKindTypeObject, Py_Get_ID(__module__), Py_Get_ID(optree)); + py::setattr(PyTreeKindTypeObject, "NUM_KINDS", py::int_(py::ssize_t(PyTreeKind::NumKinds))); auto PyTreeSpecTypeObject = #if defined(PYBIND11_HAS_INTERNALS_WITH_SMART_HOLDER_SUPPORT) @@ -303,7 +299,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] py::module_local()); auto * const PyTreeSpec_Type = reinterpret_cast(PyTreeSpecTypeObject.ptr()); PyTreeSpec_Type->tp_name = "optree.PyTreeSpec"; - py::setattr(PyTreeSpecTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree)); + py::setattr(PyTreeSpecTypeObject, Py_Get_ID(__module__), Py_Get_ID(optree)); PyTreeSpecTypeObject .def("unflatten", @@ -501,7 +497,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] py::module_local()); auto * const PyTreeIter_Type = reinterpret_cast(PyTreeIterTypeObject.ptr()); PyTreeIter_Type->tp_name = "optree.PyTreeIter"; - py::setattr(PyTreeIterTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree)); + py::setattr(PyTreeIterTypeObject, Py_Get_ID(__module__), Py_Get_ID(optree)); PyTreeIterTypeObject .def(py::init, bool, std::string>(), diff --git a/src/treespec/gc.cpp b/src/treespec/gc.cpp index 025e6178..7cdb52c2 100644 --- a/src/treespec/gc.cpp +++ b/src/treespec/gc.cpp @@ -62,8 +62,8 @@ namespace optree { return 0; } auto &self = thread_safe_cast(py::handle{self_base}); - for (const auto &pair : self.m_agenda) { - Py_VISIT(pair.first.ptr()); + for (const auto &[obj, _] : self.m_agenda) { + Py_VISIT(obj.ptr()); } Py_VISIT(self.m_root.ptr()); return 0; diff --git a/src/treespec/serialization.cpp b/src/treespec/serialization.cpp index 99af5beb..b80770c9 100644 --- a/src/treespec/serialization.cpp +++ b/src/treespec/serialization.cpp @@ -297,7 +297,7 @@ py::object PyTreeSpec::ToPickleable() const { ssize_t i = 0; for (const auto &node : m_traversal) { const scoped_critical_section2 cs{ - node.custom != nullptr ? py::handle{node.custom->type.ptr()} : py::handle{}, + node.custom != nullptr ? 
py::handle{node.custom->type} : py::handle{}, node.node_data}; TupleSetItem(node_states, i++, From 3ba5f292b554d0a2376a11256f58c2cf3c073609 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 6 Dec 2025 14:26:01 +0800 Subject: [PATCH 17/59] chore: sync changes --- .../test_threading.py} | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) rename tests/{test_concurrent.py => concurrent/test_threading.py} (95%) diff --git a/tests/test_concurrent.py b/tests/concurrent/test_threading.py similarity index 95% rename from tests/test_concurrent.py rename to tests/concurrent/test_threading.py index 1a74a5af..2d03897f 100644 --- a/tests/test_concurrent.py +++ b/tests/concurrent/test_threading.py @@ -55,8 +55,16 @@ atexit.register(EXECUTOR.shutdown) -def concurrent_run(func): - futures = [EXECUTOR.submit(func) for _ in range(NUM_FUTURES)] +def run(func, /, *args, **kwargs): + future = EXECUTOR.submit(func, *args, **kwargs) + exception = future.exception() + if exception is not None: + raise exception + return future.result() + + +def concurrent_run(func, /, *args, **kwargs): + futures = [EXECUTOR.submit(func, *args, **kwargs) for _ in range(NUM_FUTURES)] future2index = {future: i for i, future in enumerate(futures)} completed_futures = sorted(as_completed(futures), key=future2index.get) first_exception = next(filter(None, (future.exception() for future in completed_futures)), None) @@ -65,7 +73,7 @@ def concurrent_run(func): return [future.result() for future in completed_futures] -concurrent_run(object) # warm-up +run(object) # warm-up @parametrize( @@ -92,7 +100,7 @@ def test_fn(): for result in concurrent_run(test_fn): assert result == expected - for result in concurrent_run(lambda: optree.tree_unflatten(treespec, leaves)): + for result in concurrent_run(optree.tree_unflatten, treespec, leaves): assert result == tree @@ -353,7 +361,7 @@ def test_tree_iter_thread_safe( namespace=namespace, ) - results = concurrent_run(lambda: list(it)) + results = concurrent_run(list, it) for seq in results: assert sorted(seq) == seq assert sorted(itertools.chain.from_iterable(results)) == list(range(num_leaves)) From e276ec39994d1a571d66544cf1ceb491aaefb88e Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 8 Dec 2025 12:27:50 +0800 Subject: [PATCH 18/59] chore: set macro `OPTREE_HAS_SUBINTERPRETER_SUPPORT` --- include/optree/pymacros.h | 7 +++++++ include/optree/pytypes.h | 20 ++++++++++---------- optree/_C.pyi | 1 + src/optree.cpp | 8 ++++++-- tests/concurrent/test_subinterpreters.py | 4 ++-- tests/helpers.py | 4 ++-- 6 files changed, 28 insertions(+), 16 deletions(-) diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h index a5525742..545bc523 100644 --- a/include/optree/pymacros.h +++ b/include/optree/pymacros.h @@ -86,6 +86,9 @@ Py_Declare_ID(__qualname__); // type.__qualname__ Py_Declare_ID(__name__); // type.__name__ Py_Declare_ID(sort); // list.sort Py_Declare_ID(copy); // dict.copy +Py_Declare_ID(OrderedDict); // OrderedDict +Py_Declare_ID(defaultdict); // defaultdict +Py_Declare_ID(deque); // deque Py_Declare_ID(default_factory); // defaultdict.default_factory Py_Declare_ID(maxlen); // deque.maxlen Py_Declare_ID(_fields); // namedtuple._fields @@ -103,6 +106,8 @@ Py_Declare_ID(n_unnamed_fields); // structseq.n_unnamed_fields (defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \ NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT)) +# define OPTREE_HAS_SUBINTERPRETER_SUPPORT 1 + [[nodiscard]] inline std::int64_t GetPyInterpreterID() { PyInterpreterState *interp = 
PyInterpreterState_Get(); if (interp == nullptr) [[unlikely]] { @@ -117,6 +122,8 @@ Py_Declare_ID(n_unnamed_fields); // structseq.n_unnamed_fields #else +# undef OPTREE_HAS_SUBINTERPRETER_SUPPORT + [[nodiscard]] inline constexpr std::int64_t GetPyInterpreterID() noexcept { // Fallback for Python versions < 3.14 or when subinterpreter support is not available. return 0; diff --git a/include/optree/pytypes.h b/include/optree/pytypes.h index daf1a9f2..b3cfbabd 100644 --- a/include/optree/pytypes.h +++ b/include/optree/pytypes.h @@ -71,32 +71,32 @@ inline const py::object &ImportOrderedDict() { PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; return storage .call_once_and_store_result([]() -> py::object { - return py::getattr(py::module_::import("collections"), "OrderedDict"); + return py::getattr(py::module_::import("collections"), Py_Get_ID(OrderedDict)); }) .get_stored(); } -#if !defined(PYPY_VERSION) && \ - (defined(PY_VERSION_HEX) && PY_VERSION_HEX >= 0x030E0000 /* Python 3.14 */) && \ - (defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \ - NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT)) +#if defined(OPTREE_HAS_SUBINTERPRETER_SUPPORT) inline py::object ImportDefaultDict() { - return py::getattr(py::module_::import("collections"), "defaultdict"); + return py::getattr(py::module_::import("collections"), Py_Get_ID(defaultdict)); +} +inline py::object ImportDeque() { + return py::getattr(py::module_::import("collections"), Py_Get_ID(deque)); } -inline py::object ImportDeque() { return py::getattr(py::module_::import("collections"), "deque"); } #else inline const py::object &ImportDefaultDict() { PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; return storage .call_once_and_store_result([]() -> py::object { - return py::getattr(py::module_::import("collections"), "defaultdict"); + return py::getattr(py::module_::import("collections"), Py_Get_ID(defaultdict)); }) .get_stored(); } inline const py::object &ImportDeque() { PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; return storage - .call_once_and_store_result( - []() -> py::object { return py::getattr(py::module_::import("collections"), "deque"); }) + .call_once_and_store_result([]() -> py::object { + return py::getattr(py::module_::import("collections"), Py_Get_ID(deque)); + }) .get_stored(); } #endif diff --git a/optree/_C.pyi b/optree/_C.pyi index cb1d58e5..854f8c76 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -53,6 +53,7 @@ PYBIND11_HAS_NATIVE_ENUM: Final[bool] PYBIND11_HAS_INTERNALS_WITH_SMART_HOLDER_SUPPORT: Final[bool] PYBIND11_HAS_SUBINTERPRETER_SUPPORT: Final[bool] GLIBCXX_USE_CXX11_ABI: Final[bool] +OPTREE_HAS_SUBINTERPRETER_SUPPORT: Final[bool] @final class InternalError(SystemError): ... 
diff --git a/src/optree.cpp b/src/optree.cpp index 5c41ec63..aebc7dd2 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -96,6 +96,11 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] #else BUILDTIME_METADATA["GLIBCXX_USE_CXX11_ABI"] = py::bool_(false); #endif +#if defined(OPTREE_HAS_SUBINTERPRETER_SUPPORT) + BUILDTIME_METADATA["OPTREE_HAS_SUBINTERPRETER_SUPPORT"] = py::bool_(true); +#else + BUILDTIME_METADATA["OPTREE_HAS_SUBINTERPRETER_SUPPORT"] = py::bool_(false); +#endif mod.attr("BUILDTIME_METADATA") = std::move(BUILDTIME_METADATA); py::exec( @@ -583,8 +588,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] // NOLINTBEGIN[cppcoreguidelines-pro-bounds-pointer-arithmetic,cppcoreguidelines-pro-type-vararg] #if PYBIND11_VERSION_HEX >= 0x020D00F0 // pybind11 2.13.0 -# if defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \ - NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) +# if defined(OPTREE_HAS_SUBINTERPRETER_SUPPORT) PYBIND11_MODULE(_C, mod, py::mod_gil_not_used(), py::multiple_interpreters::per_interpreter_gil()) # else PYBIND11_MODULE(_C, mod, py::mod_gil_not_used()) diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index 760670ac..48b8d526 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -23,7 +23,7 @@ from helpers import ( ANDROID, IOS, - PYBIND11_HAS_SUBINTERPRETER_SUPPORT, + OPTREE_HAS_SUBINTERPRETER_SUPPORT, PYPY, WASM, Py_DEBUG, @@ -38,7 +38,7 @@ or ANDROID or sys.version_info < (3, 14) or not getattr(sys.implementation, 'supports_isolated_interpreters', False) - or not PYBIND11_HAS_SUBINTERPRETER_SUPPORT + or not OPTREE_HAS_SUBINTERPRETER_SUPPORT ): pytest.skip('Test for CPython 3.14+ only', allow_module_level=True) diff --git a/tests/helpers.py b/tests/helpers.py index ebe500bb..c004ec53 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -33,8 +33,8 @@ import optree from optree._C import ( + OPTREE_HAS_SUBINTERPRETER_SUPPORT, PYBIND11_HAS_NATIVE_ENUM, - PYBIND11_HAS_SUBINTERPRETER_SUPPORT, Py_DEBUG, Py_GIL_DISABLED, get_registry_size, @@ -51,7 +51,7 @@ assert INITIAL_REGISTRY_SIZE + 2 == len(NODETYPE_REGISTRY) _ = PYBIND11_HAS_NATIVE_ENUM -_ = PYBIND11_HAS_SUBINTERPRETER_SUPPORT +_ = OPTREE_HAS_SUBINTERPRETER_SUPPORT if sysconfig.get_config_var('Py_DEBUG') is None: assert Py_DEBUG == hasattr(sys, 'gettotalrefcount') From 86aa6f6bef19ccc28f29dd94240f085d0e14efd6 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 8 Dec 2025 12:32:13 +0800 Subject: [PATCH 19/59] fix: fix find command --- .github/workflows/tests-with-pydebug.yml | 4 ++-- .github/workflows/tests.yml | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/tests-with-pydebug.yml b/.github/workflows/tests-with-pydebug.yml index 722fca3a..1d86e36a 100644 --- a/.github/workflows/tests-with-pydebug.yml +++ b/.github/workflows/tests-with-pydebug.yml @@ -351,7 +351,7 @@ jobs: CORE_DUMP_FILES="$( find . -type d -path "./venv" -prune \ - -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + -o '(' -iname "core.*.[1-9]*" -o -iname "core_*.dmp" ')' -print )" if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "::error::Coredump files found, indicating a crash during tests." >&2 @@ -376,7 +376,7 @@ jobs: run: | CORE_DUMP_FILES="$( find . 
-type d -path "./venv" -prune \ - -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + -o '(' -iname "core.*.[1-9]*" -o -iname "core_*.dmp" ')' -print )" if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "Found core dumps:" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7cd1dda6..7edc4e89 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -216,7 +216,7 @@ jobs: fi CORE_DUMP_FILES="$( find . -type d -path "./venv" -prune \ - -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + -o '(' -iname "core.*.[1-9]*" -o -iname "core_*.dmp" ')' -print )" if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "::error::Coredump files found, indicating a crash during tests." >&2 @@ -275,7 +275,7 @@ jobs: fi CORE_DUMP_FILES="$( find . -type d -path "./venv" -prune \ - -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + -o '(' -iname "core.*.[1-9]*" -o -iname "core_*.dmp" ')' -print )" if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "::error::Coredump files found, indicating a crash during tests." >&2 @@ -319,7 +319,7 @@ jobs: fi CORE_DUMP_FILES="$( find . -type d -path "./venv" -prune \ - -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + -o '(' -iname "core.*.[1-9]*" -o -iname "core_*.dmp" ')' -print )" if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "::error::Coredump files found, indicating a crash during tests." >&2 @@ -343,7 +343,7 @@ jobs: CORE_DUMP_FILES="$( find . -type d -path "./venv" -prune \ - -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + -o '(' -iname "core.*.[1-9]*" -o -iname "core_*.dmp" ')' -print )" if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "::error::Coredump files found, indicating a crash during tests." >&2 @@ -368,7 +368,7 @@ jobs: run: | CORE_DUMP_FILES="$( find . -type d -path "./venv" -prune \ - -o -iname "core.*.[1-9]*" -o -iname "core_*.dmp" -print + -o '(' -iname "core.*.[1-9]*" -o -iname "core_*.dmp" ')' -print )" if [[ -n "${CORE_DUMP_FILES}" ]]; then echo "Found core dumps:" From f955df053de542ad43949391a9ed5150357c4e48 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 8 Dec 2025 12:51:02 +0800 Subject: [PATCH 20/59] refactor: change pointers to references --- include/optree/pymacros.h | 17 ++- include/optree/registry.h | 14 ++- include/optree/synchronization.h | 2 +- pyproject.toml | 1 - src/optree.cpp | 16 +-- src/registry.cpp | 188 +++++++++++++++---------------- 6 files changed, 118 insertions(+), 120 deletions(-) diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h index 545bc523..47ecfa93 100644 --- a/include/optree/pymacros.h +++ b/include/optree/pymacros.h @@ -101,6 +101,8 @@ Py_Declare_ID(n_unnamed_fields); // structseq.n_unnamed_fields // NOLINTNEXTLINE[bugprone-macro-parentheses] #define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0)) +using interpid_t = std::int64_t; + #if !defined(PYPY_VERSION) && \ (defined(PY_VERSION_HEX) && PY_VERSION_HEX >= 0x030E0000 /* Python 3.14 */) && \ (defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \ @@ -108,23 +110,26 @@ Py_Declare_ID(n_unnamed_fields); // structseq.n_unnamed_fields # define OPTREE_HAS_SUBINTERPRETER_SUPPORT 1 -[[nodiscard]] inline std::int64_t GetPyInterpreterID() { +[[nodiscard]] inline interpid_t GetPyInterpreterID() { PyInterpreterState *interp = PyInterpreterState_Get(); + if (PyErr_Occurred() != nullptr) [[unlikely]] { + throw py::error_already_set(); + } if (interp == nullptr) [[unlikely]] { throw std::runtime_error("Failed to get the current Python interpreter state."); } - const std::int64_t id = 
PyInterpreterState_GetID(interp); - if (id < 0) [[unlikely]] { - throw std::runtime_error("Failed to get the current Python interpreter ID (invalid ID)."); + const interpid_t interpid = PyInterpreterState_GetID(interp); + if (PyErr_Occurred() != nullptr) [[unlikely]] { + throw py::error_already_set(); } - return id; + return interpid; } #else # undef OPTREE_HAS_SUBINTERPRETER_SUPPORT -[[nodiscard]] inline constexpr std::int64_t GetPyInterpreterID() noexcept { +[[nodiscard]] inline constexpr interpid_t GetPyInterpreterID() noexcept { // Fallback for Python versions < 3.14 or when subinterpreter support is not available. return 0; } diff --git a/include/optree/registry.h b/include/optree/registry.h index d9c54671..cbded54a 100644 --- a/include/optree/registry.h +++ b/include/optree/registry.h @@ -29,6 +29,7 @@ limitations under the License. #include #include "optree/hashing.h" +#include "optree/pymacros.h" #include "optree/synchronization.h" namespace optree { @@ -153,15 +154,16 @@ class PyTreeTypeRegistry { std::unordered_map, RegistrationPtr>; using BuiltinsTypesSet = std::unordered_set; - template - [[nodiscard]] inline std::tuple - GetRegistrationsForInterpreterLocked() const; + // Get the registrations for the current Python interpreter. + [[nodiscard]] inline Py_ALWAYS_INLINE std:: + tuple + GetRegistrationsForCurrentPyInterpreterLocked() const; bool m_none_is_leaf = false; - std::unordered_map m_registrations{}; - std::unordered_map m_named_registrations{}; + std::unordered_map m_registrations{}; + std::unordered_map m_named_registrations{}; - static inline std::unordered_map sm_builtins_types{}; + static inline std::unordered_map sm_builtins_types{}; static inline read_write_mutex sm_mutex{}; static inline ssize_t sm_num_interpreters_alive = 0; static inline ssize_t sm_num_interpreters_seen = 0; diff --git a/include/optree/synchronization.h b/include/optree/synchronization.h index 461e82ee..83203f48 100644 --- a/include/optree/synchronization.h +++ b/include/optree/synchronization.h @@ -60,7 +60,7 @@ using scoped_lock = std::scoped_lock; using scoped_recursive_lock = std::scoped_lock; #if (defined(__APPLE__) /* header is not available on macOS build target */ && \ - PY_VERSION_HEX < /* Python 3.12.0 */ 0x030C00F0) + PY_VERSION_HEX < 0x030C00F0 /* Python 3.12.0 */) # undef HAVE_READ_WRITE_LOCK diff --git a/pyproject.toml b/pyproject.toml index 61fd419f..86c0838e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,6 @@ test = [ "pytest", "pytest-cov", "covdefaults", - "xdoctest", "rich", # Pin to minimum for compatibility test "typing-extensions == 4.6.0; python_version < '3.13' and platform_system == 'Linux'", diff --git a/src/optree.cpp b/src/optree.cpp index aebc7dd2..cfefbdec 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -157,14 +157,14 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] py::arg("namespace") = "") .def( "get_num_interpreters_seen", - []() -> size_t { + []() -> ssize_t { const scoped_read_lock lock{PyTreeTypeRegistry::sm_mutex}; return PyTreeTypeRegistry::sm_num_interpreters_seen; }, "Get the number of interpreters that have seen the registry.") .def( "get_num_interpreters_alive", - []() -> size_t { + []() -> ssize_t { const scoped_read_lock lock{PyTreeTypeRegistry::sm_mutex}; EXPECT_EQ(py::ssize_t_cast(PyTreeTypeRegistry::sm_builtins_types.size()), PyTreeTypeRegistry::sm_num_interpreters_alive, @@ -175,17 +175,17 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] "Get the number of alive interpreters that have 
seen the registry.") .def( "get_alive_interpreter_ids", - []() -> std::unordered_set { + []() -> std::unordered_set { const scoped_read_lock lock{PyTreeTypeRegistry::sm_mutex}; EXPECT_EQ(py::ssize_t_cast(PyTreeTypeRegistry::sm_builtins_types.size()), PyTreeTypeRegistry::sm_num_interpreters_alive, "The number of alive interpreters should match the size of the " "interpreter-scoped registered types map."); - std::unordered_set ids; - for (const auto &[id, _] : PyTreeTypeRegistry::sm_builtins_types) { - ids.insert(id); + std::unordered_set interpids; + for (const auto &[interpid, _] : PyTreeTypeRegistry::sm_builtins_types) { + interpids.insert(interpid); } - return ids; + return interpids; }, "Get the IDs of alive interpreters that have seen the registry.") .def("get_current_interpreter_id", @@ -207,7 +207,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] py::arg("namespace") = std::nullopt) .def("flatten", &PyTreeSpec::Flatten, - "Flattens a pytree.", + "Flatten a pytree.", py::arg("tree"), py::pos_only(), py::arg("leaf_predicate") = std::nullopt, diff --git a/src/registry.cpp b/src/registry.cpp index 4f48bfa7..9c9b97dc 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -15,14 +15,13 @@ limitations under the License. ================================================================================ */ -#include // std::make_shared -#include // std::optional -#include // std::ostringstream -#include // std::string -#include // std::tuple, std::make_tuple -#include // std::remove_const_t -#include // std::unordered_set -#include // std::move, std::make_pair +#include // std::make_shared +#include // std::optional +#include // std::ostringstream +#include // std::string +#include // std::tuple +#include // std::remove_const_t +#include // std::move, std::make_pair #include @@ -45,65 +44,58 @@ template template PyTreeTypeRegistry &PyTreeTypeRegistry::GetSingleton(); template PyTreeTypeRegistry &PyTreeTypeRegistry::GetSingleton(); -template -std::tuple -PyTreeTypeRegistry::GetRegistrationsForInterpreterLocked() const { - const auto interpreter_id = GetPyInterpreterID(); +std::tuple +PyTreeTypeRegistry::GetRegistrationsForCurrentPyInterpreterLocked() const { + const auto interpid = GetPyInterpreterID(); + EXPECT_NE(m_registrations.find(interpid), + m_registrations.end(), + "Interpreter ID " + std::to_string(interpid) + " not found in `m_registrations`."); EXPECT_NE( - m_registrations.find(interpreter_id), - m_registrations.end(), - "Interpreter ID " + std::to_string(interpreter_id) + " not found in `m_registrations`."); - EXPECT_NE(m_named_registrations.find(interpreter_id), - m_named_registrations.end(), - "Interpreter ID " + std::to_string(interpreter_id) + - " not found in `m_named_registrations`."); - EXPECT_NE( - sm_builtins_types.find(interpreter_id), - sm_builtins_types.end(), - "Interpreter ID " + std::to_string(interpreter_id) + " not found in `sm_builtins_types`."); - - const auto ®istrations = m_registrations.at(interpreter_id); - const auto &named_registrations = m_named_registrations.at(interpreter_id); - const auto &builtins_types = sm_builtins_types.at(interpreter_id); + m_named_registrations.find(interpid), + m_named_registrations.end(), + "Interpreter ID " + std::to_string(interpid) + " not found in `m_named_registrations`."); + EXPECT_NE(sm_builtins_types.find(interpid), + sm_builtins_types.end(), + "Interpreter ID " + std::to_string(interpid) + " not found in `sm_builtins_types`."); // NOLINTBEGIN[cppcoreguidelines-pro-type-const-cast] - return 
std::make_tuple(const_cast(®istrations), - const_cast(&named_registrations), - const_cast(&builtins_types)); + return {const_cast(m_registrations.at(interpid)), + const_cast(m_named_registrations.at(interpid)), + const_cast(sm_builtins_types.at(interpid))}; // NOLINTEND[cppcoreguidelines-pro-type-const-cast] } void PyTreeTypeRegistry::Init() { const scoped_write_lock lock{sm_mutex}; - const auto interpreter_id = GetPyInterpreterID(); + const auto interpid = GetPyInterpreterID(); - EXPECT_EQ(m_registrations.find(interpreter_id), + EXPECT_EQ(m_registrations.find(interpid), m_registrations.end(), - "Interpreter ID " + std::to_string(interpreter_id) + + "Interpreter ID " + std::to_string(interpid) + " is already initialized in `m_registrations`."); - EXPECT_EQ(m_named_registrations.find(interpreter_id), + EXPECT_EQ(m_named_registrations.find(interpid), m_named_registrations.end(), - "Interpreter ID " + std::to_string(interpreter_id) + + "Interpreter ID " + std::to_string(interpid) + " is already initialized in `m_named_registrations`."); if (!m_none_is_leaf) [[likely]] { - EXPECT_EQ(sm_builtins_types.find(interpreter_id), + EXPECT_EQ(sm_builtins_types.find(interpid), sm_builtins_types.end(), - "Interpreter ID " + std::to_string(interpreter_id) + + "Interpreter ID " + std::to_string(interpid) + " is already initialized in `sm_builtins_types`."); } else { - EXPECT_NE(sm_builtins_types.find(interpreter_id), + EXPECT_NE(sm_builtins_types.find(interpid), sm_builtins_types.end(), - "Interpreter ID " + std::to_string(interpreter_id) + + "Interpreter ID " + std::to_string(interpid) + " is not initialized in `sm_builtins_types`."); } - auto ®istrations = m_registrations.try_emplace(interpreter_id).first->second; - auto &named_registrations = m_named_registrations.try_emplace(interpreter_id).first->second; - auto &builtins_types = sm_builtins_types.try_emplace(interpreter_id).first->second; + auto ®istrations = m_registrations.try_emplace(interpid).first->second; + auto &named_registrations = m_named_registrations.try_emplace(interpid).first->second; + auto &builtins_types = sm_builtins_types.try_emplace(interpid).first->second; (void)named_registrations; // silence unused variable warning @@ -133,13 +125,13 @@ void PyTreeTypeRegistry::Init() { ssize_t PyTreeTypeRegistry::Size(const std::optional ®istry_namespace) const { const scoped_read_lock lock{sm_mutex}; - const auto [registrations, named_registrations, builtins_types] = - GetRegistrationsForInterpreterLocked(); + const auto &[registrations, named_registrations, builtins_types] = + GetRegistrationsForCurrentPyInterpreterLocked(); (void)builtins_types; // silence unused variable warning - ssize_t count = py::ssize_t_cast(registrations->size()); - for (const auto &[named_type, _] : *named_registrations) { + ssize_t count = py::ssize_t_cast(registrations.size()); + for (const auto &[named_type, _] : named_registrations) { if (!registry_namespace || named_type.first == *registry_namespace) [[likely]] { ++count; } @@ -155,10 +147,10 @@ template const std::string ®istry_namespace) { const auto ®istry = GetSingleton(); - const auto [registrations, named_registrations, builtins_types] = - registry.GetRegistrationsForInterpreterLocked(); + auto [registrations, named_registrations, builtins_types] = + registry.GetRegistrationsForCurrentPyInterpreterLocked(); - if (builtins_types->find(cls) != builtins_types->end()) [[unlikely]] { + if (builtins_types.find(cls) != builtins_types.end()) [[unlikely]] { throw py::value_error("PyTree type " + PyRepr(cls) + " is 
a built-in type and cannot be re-registered."); } @@ -170,7 +162,7 @@ template registration->unflatten_func = py::reinterpret_borrow(unflatten_func); registration->path_entry_type = py::reinterpret_borrow(path_entry_type); if (registry_namespace.empty()) [[unlikely]] { - if (!registrations->emplace(cls, std::move(registration)).second) [[unlikely]] { + if (!registrations.emplace(cls, std::move(registration)).second) [[unlikely]] { throw py::value_error("PyTree type " + PyRepr(cls) + " is already registered in the global namespace."); } @@ -193,7 +185,7 @@ template } } else [[likely]] { if (!named_registrations - ->emplace(std::make_pair(registry_namespace, cls), std::move(registration)) + .emplace(std::make_pair(registry_namespace, cls), std::move(registration)) .second) [[unlikely]] { std::ostringstream oss{}; oss << "PyTree type " << PyRepr(cls) << " is already registered in namespace " @@ -253,17 +245,17 @@ template const std::string ®istry_namespace) { const auto ®istry = GetSingleton(); - const auto [registrations, named_registrations, builtins_types] = - registry.GetRegistrationsForInterpreterLocked(); + auto [registrations, named_registrations, builtins_types] = + registry.GetRegistrationsForCurrentPyInterpreterLocked(); - if (builtins_types->find(cls) != builtins_types->end()) [[unlikely]] { + if (builtins_types.find(cls) != builtins_types.end()) [[unlikely]] { throw py::value_error("PyTree type " + PyRepr(cls) + " is a built-in type and cannot be unregistered."); } if (registry_namespace.empty()) [[unlikely]] { - const auto it = registrations->find(cls); - if (it == registrations->end()) [[unlikely]] { + const auto it = registrations.find(cls); + if (it == registrations.end()) [[unlikely]] { std::ostringstream oss{}; oss << "PyTree type " << PyRepr(cls) << " "; if (IsStructSequenceClass(cls)) [[unlikely]] { @@ -278,11 +270,11 @@ template throw py::value_error(oss.str()); } RegistrationPtr registration = it->second; - registrations->erase(it); + registrations.erase(it); return registration; } else [[likely]] { - const auto named_it = named_registrations->find(std::make_pair(registry_namespace, cls)); - if (named_it == named_registrations->end()) [[unlikely]] { + const auto named_it = named_registrations.find(std::make_pair(registry_namespace, cls)); + if (named_it == named_registrations.end()) [[unlikely]] { std::ostringstream oss{}; oss << "PyTree type " << PyRepr(cls) << " "; if (IsStructSequenceClass(cls)) [[unlikely]] { @@ -298,7 +290,7 @@ template throw py::value_error(oss.str()); } RegistrationPtr registration = named_it->second; - named_registrations->erase(named_it); + named_registrations.erase(named_it); return registration; } } @@ -327,17 +319,17 @@ template const auto ®istry = GetSingleton(); - const auto [registrations, named_registrations, _] = - registry.GetRegistrationsForInterpreterLocked(); + const auto &[registrations, named_registrations, _] = + registry.GetRegistrationsForCurrentPyInterpreterLocked(); if (!registry_namespace.empty()) [[unlikely]] { - const auto named_it = named_registrations->find(std::make_pair(registry_namespace, cls)); - if (named_it != named_registrations->end()) [[likely]] { + const auto named_it = named_registrations.find(std::make_pair(registry_namespace, cls)); + if (named_it != named_registrations.end()) [[likely]] { return named_it->second; } } - const auto it = registrations->find(cls); - return it != registrations->end() ? it->second : nullptr; + const auto it = registrations.find(cls); + return it != registrations.end() ? 
it->second : nullptr; } template PyTreeTypeRegistry::RegistrationPtr PyTreeTypeRegistry::Lookup( @@ -385,28 +377,28 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( /*static*/ void PyTreeTypeRegistry::Clear() { const scoped_write_lock lock{sm_mutex}; - const auto interpreter_id = GetPyInterpreterID(); + const auto interpid = GetPyInterpreterID(); auto ®istry1 = GetSingleton(); auto ®istry2 = GetSingleton(); - const auto [registrations1, named_registrations1, builtins_types] = - registry1.GetRegistrationsForInterpreterLocked(); - const auto [registrations2, named_registrations2, builtins_types_] = - registry2.GetRegistrationsForInterpreterLocked(); + auto [registrations1, named_registrations1, builtins_types] = + registry1.GetRegistrationsForCurrentPyInterpreterLocked(); + auto [registrations2, named_registrations2, builtins_types_] = + registry2.GetRegistrationsForCurrentPyInterpreterLocked(); - EXPECT_LE(builtins_types->size(), registrations1->size()); - EXPECT_EQ(registrations1->size(), registrations2->size() + 1); - EXPECT_EQ(named_registrations1->size(), named_registrations2->size()); - EXPECT_EQ(builtins_types, builtins_types_); + EXPECT_LE(builtins_types.size(), registrations1.size()); + EXPECT_EQ(registrations1.size(), registrations2.size() + 1); + EXPECT_EQ(named_registrations1.size(), named_registrations2.size()); + EXPECT_EQ(&builtins_types, &builtins_types_); #if defined(Py_DEBUG) - for (const auto &cls : *builtins_types) { - EXPECT_NE(registrations1->find(cls), registrations1->end()); + for (const auto &cls : builtins_types) { + EXPECT_NE(registrations1.find(cls), registrations1.end()); } - for (const auto &[cls2, registration2] : *registrations2) { - const auto it1 = registrations1->find(cls2); - EXPECT_NE(it1, registrations1->end()); + for (const auto &[cls2, registration2] : registrations2) { + const auto it1 = registrations1.find(cls2); + EXPECT_NE(it1, registrations1.end()); const auto ®istration1 = it1->second; EXPECT_TRUE(registration1->type.is(registration2->type)); @@ -414,9 +406,9 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( EXPECT_TRUE(registration1->unflatten_func.is(registration2->unflatten_func)); EXPECT_TRUE(registration1->path_entry_type.is(registration2->path_entry_type)); } - for (const auto &[named_cls2, registration2] : *named_registrations2) { - const auto it1 = named_registrations1->find(named_cls2); - EXPECT_NE(it1, named_registrations1->end()); + for (const auto &[named_cls2, registration2] : named_registrations2) { + const auto it1 = named_registrations1.find(named_cls2); + EXPECT_NE(it1, named_registrations1.end()); const auto ®istration1 = it1->second; EXPECT_TRUE(registration1->type.is(registration2->type)); @@ -428,30 +420,30 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( EXPECT_EQ(py::ssize_t_cast(sm_builtins_types.size()), sm_num_interpreters_alive); - for (const auto &[_, registration] : *registrations1) { + for (const auto &[_, registration] : registrations1) { registration->type.dec_ref(); registration->flatten_func.dec_ref(); registration->unflatten_func.dec_ref(); registration->path_entry_type.dec_ref(); } - for (const auto &[_, registration] : *named_registrations1) { + for (const auto &[_, registration] : named_registrations1) { registration->type.dec_ref(); registration->flatten_func.dec_ref(); registration->unflatten_func.dec_ref(); registration->path_entry_type.dec_ref(); } - builtins_types->clear(); - registrations1->clear(); - named_registrations1->clear(); - registrations2->clear(); - named_registrations2->clear(); - - 
sm_builtins_types.erase(interpreter_id); - registry1.m_registrations.erase(interpreter_id); - registry1.m_named_registrations.erase(interpreter_id); - registry2.m_registrations.erase(interpreter_id); - registry2.m_named_registrations.erase(interpreter_id); + builtins_types.clear(); + registrations1.clear(); + named_registrations1.clear(); + registrations2.clear(); + named_registrations2.clear(); + + sm_builtins_types.erase(interpid); + registry1.m_registrations.erase(interpid); + registry1.m_named_registrations.erase(interpid); + registry2.m_registrations.erase(interpid); + registry2.m_named_registrations.erase(interpid); --sm_num_interpreters_alive; } From e2463e214e7b4c20fb9fd66bbadcab4525f9576e Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 8 Dec 2025 13:59:56 +0800 Subject: [PATCH 21/59] test: add import tests --- tests/concurrent/test_subinterpreters.py | 71 ++++++++++++++++++++++++ tests/helpers.py | 22 ++++++++ tests/test_ops.py | 21 +------ 3 files changed, 95 insertions(+), 19 deletions(-) diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index 48b8d526..bf6306fe 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -17,6 +17,7 @@ import contextlib import random import sys +import textwrap import pytest @@ -28,6 +29,7 @@ WASM, Py_DEBUG, Py_GIL_DISABLED, + check_script_in_subprocess, ) @@ -50,9 +52,11 @@ if Py_GIL_DISABLED and not Py_DEBUG: NUM_WORKERS = 32 NUM_FUTURES = 128 + NUM_FLAKY_RERUNS = 8 else: NUM_WORKERS = 4 NUM_FUTURES = 16 + NUM_FLAKY_RERUNS = 5 EXECUTOR = InterpreterPoolExecutor(max_workers=NUM_WORKERS) @@ -157,3 +161,70 @@ def test_import(): random.shuffle(subinterpreters) for subinterpreter in subinterpreters: assert subinterpreter.call(check_module_importable) == expected + + +def test_import_in_subinterpreter_after_main(): + check_script_in_subprocess( + textwrap.dedent( + """ + import contextlib + import gc + from concurrent import interpreters + + import optree + + subinterpreter = None + with contextlib.closing(interpreters.create()) as subinterpreter: + subinterpreter.exec('import optree') + + del optree, subinterpreter + for _ in range(10): + gc.collect() + """, + ).strip(), + rerun=NUM_FLAKY_RERUNS, + ) + + +def test_import_in_subinterpreter_before_main(): + check_script_in_subprocess( + textwrap.dedent( + """ + import contextlib + import gc + from concurrent import interpreters + + subinterpreter = None + with contextlib.closing(interpreters.create()) as subinterpreter: + subinterpreter.exec('import optree') + + import optree + del optree, subinterpreter + for _ in range(10): + gc.collect() + """, + ).strip(), + rerun=NUM_FLAKY_RERUNS, + ) + + +def test_import_in_subinterpreters_concurrently(): + check_script_in_subprocess( + textwrap.dedent( + f""" + from concurrent.futures import InterpreterPoolExecutor, as_completed + + def check_import(): + import optree + + if optree._C.get_registry_size() != 8: + raise RuntimeError('registry size mismatch') + + with InterpreterPoolExecutor(max_workers={NUM_WORKERS}) as executor: + futures = [executor.submit(check_import) for _ in range({NUM_FUTURES})] + for future in as_completed(futures): + future.result() + """, + ).strip(), + rerun=NUM_FLAKY_RERUNS, + ) diff --git a/tests/helpers.py b/tests/helpers.py index c004ec53..baa47882 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -20,7 +20,9 @@ import functools import gc import itertools +import os import platform +import subprocess import sys import sysconfig 
import time @@ -154,6 +156,26 @@ def wrapper(*args, **kwargs): return wrapper +def check_script_in_subprocess(script, /, *, rerun=1): + env = { + key: value + for key, value in os.environ.items() + if not key.startswith(('PYTHON', 'PYTEST', 'COV_')) + } + for _ in range(rerun): + assert ( + subprocess.check_output( + [sys.executable, '-Walways', '-Werror', '-c', script], + stderr=subprocess.STDOUT, + text=True, + encoding='utf-8', + cwd=TEST_ROOT, + env=env, + ) + == '' + ) + + MISSING = object() diff --git a/tests/test_ops.py b/tests/test_ops.py index 6ceebe4d..02752bc3 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -19,11 +19,9 @@ import functools import itertools import operator -import os import pickle import platform import re -import subprocess import sys from collections import OrderedDict, defaultdict, deque @@ -34,7 +32,6 @@ GLOBAL_NAMESPACE, IS_LEAF_FUNCTIONS, LEAVES, - TEST_ROOT, TREE_ACCESSORS, TREE_PATHS, TREES, @@ -45,6 +42,7 @@ Py_DEBUG, always, assert_equal_type_and_value, + check_script_in_subprocess, is_list, is_none, is_tuple, @@ -60,22 +58,7 @@ @skipif_android @skipif_ios def test_import_no_warnings(): - env = { - key: value - for key, value in os.environ.items() - if not key.startswith(('PYTHON', 'PYTEST', 'COV_')) - } - assert ( - subprocess.check_output( - [sys.executable, '-Walways', '-Werror', '-c', 'import optree'], - stderr=subprocess.STDOUT, - text=True, - encoding='utf-8', - cwd=TEST_ROOT, - env=env, - ) - == '' - ) + check_script_in_subprocess('import optree') def test_max_depth(): From be1fb8e2766cbdb5498c984d92c0e38fd9b8401f Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 8 Dec 2025 17:15:24 +0800 Subject: [PATCH 22/59] fix: do not use `py::gil_safe_call_once_and_store` for subinterpreters --- include/optree/pymacros.h | 68 +++++++++++++++++++++++---------------- 1 file changed, 41 insertions(+), 27 deletions(-) diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h index 47ecfa93..281057a0 100644 --- a/include/optree/pymacros.h +++ b/include/optree/pymacros.h @@ -32,6 +32,17 @@ limitations under the License. # error "pybind11 2.12.0 or newer is required." 
#endif +// NOLINTNEXTLINE[bugprone-macro-parentheses] +#define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0)) +#if !defined(PYPY_VERSION) && \ + (defined(PY_VERSION_HEX) && PY_VERSION_HEX >= 0x030E0000 /* Python 3.14 */) && \ + (defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \ + NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT)) +# define OPTREE_HAS_SUBINTERPRETER_SUPPORT 1 +#else +# undef OPTREE_HAS_SUBINTERPRETER_SUPPORT +#endif + namespace py = pybind11; #if !defined(Py_ALWAYS_INLINE) @@ -60,22 +71,35 @@ inline constexpr Py_ALWAYS_INLINE bool Py_IsConstant(PyObject *x) noexcept { } #define Py_IsConstant(x) Py_IsConstant(x) -#define Py_Declare_ID(name) \ - namespace { \ - [[nodiscard]] inline PyObject *Py_ID_##name() { \ - PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; \ - return storage \ - .call_once_and_store_result([]() -> PyObject * { \ - PyObject * const ptr = PyUnicode_InternFromString(#name); \ - if (ptr == nullptr) [[unlikely]] { \ - throw py::error_already_set(); \ - } \ - Py_INCREF(ptr); /* leak a reference on purpose */ \ - return ptr; \ - }) \ - .get_stored(); \ - } \ - } // namespace +#if defined(OPTREE_HAS_SUBINTERPRETER_SUPPORT) +# define Py_Declare_ID(name) \ + namespace { \ + [[nodiscard]] inline PyObject *Py_ID_##name() { \ + PyObject * const ptr = PyUnicode_InternFromString(#name); \ + if (ptr == nullptr) [[unlikely]] { \ + throw py::error_already_set(); \ + } \ + return ptr; \ + } \ + } // namespace +#else +# define Py_Declare_ID(name) \ + namespace { \ + [[nodiscard]] inline PyObject *Py_ID_##name() { \ + PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; \ + return storage \ + .call_once_and_store_result([]() -> PyObject * { \ + PyObject * const ptr = PyUnicode_InternFromString(#name); \ + if (ptr == nullptr) [[unlikely]] { \ + throw py::error_already_set(); \ + } \ + Py_INCREF(ptr); /* leak a reference on purpose */ \ + return ptr; \ + }) \ + .get_stored(); \ + } \ + } // namespace +#endif #define Py_Get_ID(name) (::Py_ID_##name()) @@ -98,17 +122,9 @@ Py_Declare_ID(n_fields); // structseq.n_fields Py_Declare_ID(n_sequence_fields); // structseq.n_sequence_fields Py_Declare_ID(n_unnamed_fields); // structseq.n_unnamed_fields -// NOLINTNEXTLINE[bugprone-macro-parentheses] -#define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0)) - using interpid_t = std::int64_t; -#if !defined(PYPY_VERSION) && \ - (defined(PY_VERSION_HEX) && PY_VERSION_HEX >= 0x030E0000 /* Python 3.14 */) && \ - (defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \ - NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT)) - -# define OPTREE_HAS_SUBINTERPRETER_SUPPORT 1 +#if defined(OPTREE_HAS_SUBINTERPRETER_SUPPORT) [[nodiscard]] inline interpid_t GetPyInterpreterID() { PyInterpreterState *interp = PyInterpreterState_Get(); @@ -127,8 +143,6 @@ using interpid_t = std::int64_t; #else -# undef OPTREE_HAS_SUBINTERPRETER_SUPPORT - [[nodiscard]] inline constexpr interpid_t GetPyInterpreterID() noexcept { // Fallback for Python versions < 3.14 or when subinterpreter support is not available. 
return 0; From 3495e721b887d77e8fb808e9a19b7ac6f9ae0789 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 8 Dec 2025 20:50:53 +0800 Subject: [PATCH 23/59] chore: reorgainize code --- include/optree/registry.h | 44 ++++++++++++++++++++++++ include/optree/treespec.h | 2 +- optree/_C.pyi | 2 +- src/optree.cpp | 71 ++++++++++----------------------------- 4 files changed, 64 insertions(+), 55 deletions(-) diff --git a/include/optree/registry.h b/include/optree/registry.h index cbded54a..e3e92e6f 100644 --- a/include/optree/registry.h +++ b/include/optree/registry.h @@ -28,6 +28,7 @@ limitations under the License. #include +#include "optree/exceptions.h" #include "optree/hashing.h" #include "optree/pymacros.h" #include "optree/synchronization.h" @@ -113,6 +114,7 @@ class PyTreeTypeRegistry { const py::object &path_entry_type, const std::string ®istry_namespace = ""); + // Unregister a previously registered custom type. static void Unregister(const py::object &cls, const std::string ®istry_namespace = ""); // Find the custom type registration for `type`. Returns nullptr if none exists. @@ -126,6 +128,48 @@ class PyTreeTypeRegistry { RegistrationPtr &custom, // NOLINT[runtime/references] const std::string ®istry_namespace); + // Get the number of registered types. + [[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetRegistrySize( + const std::optional ®istry_namespace = std::nullopt) { + const ssize_t count = GetSingleton().Size(registry_namespace); + EXPECT_EQ(count, + GetSingleton().Size(registry_namespace) + 1, + "The number of registered types in the two registries should match " + "up to the extra None type in the NoneIsNode registry."); + return count; + } + + // Get the number of alive interpreters that have seen the registry. + [[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersAlive() { + const scoped_read_lock lock{sm_mutex}; + return sm_num_interpreters_seen; + } + + // Get the number of interpreters that have seen the registry. + [[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersSeen() { + const scoped_read_lock lock{sm_mutex}; + EXPECT_EQ(py::ssize_t_cast(sm_builtins_types.size()), + sm_num_interpreters_alive, + "The number of alive interpreters should match the size of the " + "interpreter-scoped registered types map."); + return sm_num_interpreters_alive; + } + + // Get the IDs of alive interpreters that have seen the registry. 
+ [[nodiscard]] static inline Py_ALWAYS_INLINE std::unordered_set + GetAliveInterpreterIDs() { + const scoped_read_lock lock{sm_mutex}; + EXPECT_EQ(py::ssize_t_cast(sm_builtins_types.size()), + sm_num_interpreters_alive, + "The number of alive interpreters should match the size of the " + "interpreter-scoped registered types map."); + std::unordered_set interpids; + for (const auto &[interpid, _] : sm_builtins_types) { + interpids.insert(interpid); + } + return interpids; + } + friend void BuildModule(py::module_ &mod); // NOLINT[runtime/references] private: diff --git a/include/optree/treespec.h b/include/optree/treespec.h index fff08b8d..93c6e110 100644 --- a/include/optree/treespec.h +++ b/include/optree/treespec.h @@ -288,7 +288,7 @@ class PyTreeSpec { private: using RegistrationPtr = PyTreeTypeRegistry::RegistrationPtr; - using ThreadedIdentity = std::pair; + using ThreadedIdentity = std::pair; struct Node { PyTreeKind kind = PyTreeKind::Leaf; diff --git a/optree/_C.pyi b/optree/_C.pyi index 854f8c76..aa8d96e3 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -214,8 +214,8 @@ def set_dict_insertion_ordered( /, namespace: str = '', ) -> None: ... +def get_registry_size(namespace: str | None = None) -> int: ... def get_num_interpreters_seen() -> int: ... def get_num_interpreters_alive() -> int: ... def get_alive_interpreter_ids() -> set[int]: ... def get_current_interpreter_id() -> int: ... -def get_registry_size(namespace: str | None = None) -> int: ... diff --git a/src/optree.cpp b/src/optree.cpp index cfefbdec..1347ced4 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -17,12 +17,11 @@ limitations under the License. #include "optree/optree.h" -#include // std::{not_,}equal_to, std::less{,_equal}, std::greater{,_equal} -#include // std::unique_ptr -#include // std::optional, std::nullopt -#include // std::string -#include // std::unordered_set -#include // std::move +#include // std::{not_,}equal_to, std::less{,_equal}, std::greater{,_equal} +#include // std::unique_ptr +#include // std::optional, std::nullopt +#include // std::string +#include // std::move #include #include @@ -155,56 +154,22 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] py::arg("mode"), py::pos_only(), py::arg("namespace") = "") - .def( - "get_num_interpreters_seen", - []() -> ssize_t { - const scoped_read_lock lock{PyTreeTypeRegistry::sm_mutex}; - return PyTreeTypeRegistry::sm_num_interpreters_seen; - }, - "Get the number of interpreters that have seen the registry.") - .def( - "get_num_interpreters_alive", - []() -> ssize_t { - const scoped_read_lock lock{PyTreeTypeRegistry::sm_mutex}; - EXPECT_EQ(py::ssize_t_cast(PyTreeTypeRegistry::sm_builtins_types.size()), - PyTreeTypeRegistry::sm_num_interpreters_alive, - "The number of alive interpreters should match the size of the " - "interpreter-scoped registered types map."); - return PyTreeTypeRegistry::sm_num_interpreters_alive; - }, - "Get the number of alive interpreters that have seen the registry.") - .def( - "get_alive_interpreter_ids", - []() -> std::unordered_set { - const scoped_read_lock lock{PyTreeTypeRegistry::sm_mutex}; - EXPECT_EQ(py::ssize_t_cast(PyTreeTypeRegistry::sm_builtins_types.size()), - PyTreeTypeRegistry::sm_num_interpreters_alive, - "The number of alive interpreters should match the size of the " - "interpreter-scoped registered types map."); - std::unordered_set interpids; - for (const auto &[interpid, _] : PyTreeTypeRegistry::sm_builtins_types) { - interpids.insert(interpid); - } - return interpids; - }, - "Get the 
IDs of alive interpreters that have seen the registry.") + .def("get_registry_size", + &PyTreeTypeRegistry::GetRegistrySize, + "Get the number of registered types.", + py::arg("namespace") = std::nullopt) + .def("get_num_interpreters_seen", + &PyTreeTypeRegistry::GetNumInterpretersSeen, + "Get the number of interpreters that have seen the registry.") + .def("get_num_interpreters_alive", + &PyTreeTypeRegistry::GetNumInterpretersAlive, + "Get the number of alive interpreters that have seen the registry.") + .def("get_alive_interpreter_ids", + &PyTreeTypeRegistry::GetAliveInterpreterIDs, + "Get the IDs of alive interpreters that have seen the registry.") .def("get_current_interpreter_id", &GetPyInterpreterID, "Get the ID of the current interpreter.") - .def( - "get_registry_size", - [](const std::optional ®istry_namespace) { - const ssize_t count = - PyTreeTypeRegistry::GetSingleton().Size(registry_namespace); - EXPECT_EQ( - count, - PyTreeTypeRegistry::GetSingleton().Size(registry_namespace) + 1, - "The number of registered types in the two registries should match " - "up to the extra None type in the NoneIsNode registry."); - return count; - }, - "Get the number of registered types.", - py::arg("namespace") = std::nullopt) .def("flatten", &PyTreeSpec::Flatten, "Flatten a pytree.", From 2278070cf564622b5f7339f9c2d5186fc0969122 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 8 Dec 2025 23:02:53 +0800 Subject: [PATCH 24/59] feat: add function `get_main_interpreter_id()` --- include/optree/pymacros.h | 24 ++++++++++++-- include/optree/registry.h | 8 ++--- include/optree/treespec.h | 6 ++-- optree/_C.pyi | 1 + src/optree.cpp | 5 ++- src/registry.cpp | 6 ++-- tests/concurrent/test_subinterpreters.py | 42 ++++++++++++++++++++---- 7 files changed, 72 insertions(+), 20 deletions(-) diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h index 281057a0..1fd2e620 100644 --- a/include/optree/pymacros.h +++ b/include/optree/pymacros.h @@ -126,7 +126,7 @@ using interpid_t = std::int64_t; #if defined(OPTREE_HAS_SUBINTERPRETER_SUPPORT) -[[nodiscard]] inline interpid_t GetPyInterpreterID() { +[[nodiscard]] inline interpid_t GetCurrentPyInterpreterID() { PyInterpreterState *interp = PyInterpreterState_Get(); if (PyErr_Occurred() != nullptr) [[unlikely]] { throw py::error_already_set(); @@ -141,9 +141,29 @@ using interpid_t = std::int64_t; return interpid; } +[[nodiscard]] inline interpid_t GetMainPyInterpreterID() { + PyInterpreterState *interp = PyInterpreterState_Main(); + if (PyErr_Occurred() != nullptr) [[unlikely]] { + throw py::error_already_set(); + } + if (interp == nullptr) [[unlikely]] { + throw std::runtime_error("Failed to get the main Python interpreter state."); + } + const interpid_t interpid = PyInterpreterState_GetID(interp); + if (PyErr_Occurred() != nullptr) [[unlikely]] { + throw py::error_already_set(); + } + return interpid; +} + #else -[[nodiscard]] inline constexpr interpid_t GetPyInterpreterID() noexcept { +[[nodiscard]] inline constexpr interpid_t GetCurrentPyInterpreterID() noexcept { + // Fallback for Python versions < 3.14 or when subinterpreter support is not available. + return 0; +} + +[[nodiscard]] inline constexpr interpid_t GetMainPyInterpreterID() noexcept { // Fallback for Python versions < 3.14 or when subinterpreter support is not available. 
return 0; } diff --git a/include/optree/registry.h b/include/optree/registry.h index e3e92e6f..f66218df 100644 --- a/include/optree/registry.h +++ b/include/optree/registry.h @@ -106,7 +106,7 @@ class PyTreeTypeRegistry { [[nodiscard]] ssize_t Size( const std::optional ®istry_namespace = std::nullopt) const; - // Register a new custom type. Objects of `cls` will be treated as container node types in + // Register a new custom type. Objects of type `cls` will be treated as container node types in // PyTrees. static void Register(const py::object &cls, const py::function &flatten_func, @@ -117,7 +117,7 @@ class PyTreeTypeRegistry { // Unregister a previously registered custom type. static void Unregister(const py::object &cls, const std::string ®istry_namespace = ""); - // Find the custom type registration for `type`. Returns nullptr if none exists. + // Find the custom type registration for `type`. Return nullptr if none exists. template [[nodiscard]] static RegistrationPtr Lookup(const py::object &cls, const std::string ®istry_namespace); @@ -187,10 +187,10 @@ class PyTreeTypeRegistry { [[nodiscard]] static RegistrationPtr UnregisterImpl(const py::object &cls, const std::string ®istry_namespace); - // Initialize the registry for a given interpreter. + // Initialize the registry for the current interpreter. void Init(); - // Clear the registry on cleanup. + // Clear the registry on cleanup for the current interpreter. static void Clear(); using RegistrationsMap = std::unordered_map; diff --git a/include/optree/treespec.h b/include/optree/treespec.h index 93c6e110..1ff25883 100644 --- a/include/optree/treespec.h +++ b/include/optree/treespec.h @@ -333,16 +333,16 @@ class PyTreeSpec { // The registry namespace used to resolve the custom pytree node types. std::string m_namespace{}; - // Helper that returns the string representation of a node kind. + // Return the string representation of a node kind. [[nodiscard]] static std::string NodeKindToString(const Node &node); - // Helper that manufactures an instance of a node given its children. + // Manufacture an instance of a node given its children. [[nodiscard]] static py::object MakeNode( const Node &node, const py::object children[], // NOLINT[hicpp-avoid-c-arrays] const size_t &num_children); - // Helper that identifies the path entry class for a node. + // Identify the path entry class for a node. [[nodiscard]] static py::object GetPathEntryType(const Node &node); // Recursive helper used to implement Flatten(). diff --git a/optree/_C.pyi b/optree/_C.pyi index aa8d96e3..61932d33 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -219,3 +219,4 @@ def get_num_interpreters_seen() -> int: ... def get_num_interpreters_alive() -> int: ... def get_alive_interpreter_ids() -> set[int]: ... def get_current_interpreter_id() -> int: ... +def get_main_interpreter_id() -> int: ... 
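A minimal usage sketch for the interpreter-ID accessors touched by this patch (`optree._C.get_current_interpreter_id()` and the newly added `optree._C.get_main_interpreter_id()`). The helper name `running_in_main_interpreter` is made up for illustration; as the fallback paths above show, both accessors simply return 0 on builds without subinterpreter support, so the check stays true there.

import optree

def running_in_main_interpreter() -> bool:
    # Compare the current interpreter ID against the main interpreter ID.
    # On Python versions without subinterpreter support both calls fall
    # back to returning 0, so this remains True on such builds.
    return optree._C.get_current_interpreter_id() == optree._C.get_main_interpreter_id()

if __name__ == '__main__':
    print('running in main interpreter:', running_in_main_interpreter())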
diff --git a/src/optree.cpp b/src/optree.cpp index 1347ced4..9e9b2e5c 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -168,8 +168,11 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] &PyTreeTypeRegistry::GetAliveInterpreterIDs, "Get the IDs of alive interpreters that have seen the registry.") .def("get_current_interpreter_id", - &GetPyInterpreterID, + &GetCurrentPyInterpreterID, "Get the ID of the current interpreter.") + .def("get_main_interpreter_id", + &GetMainPyInterpreterID, + "Get the ID of the main interpreter.") .def("flatten", &PyTreeSpec::Flatten, "Flatten a pytree.", diff --git a/src/registry.cpp b/src/registry.cpp index 9c9b97dc..2204bf8f 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -48,7 +48,7 @@ std::tuple PyTreeTypeRegistry::GetRegistrationsForCurrentPyInterpreterLocked() const { - const auto interpid = GetPyInterpreterID(); + const auto interpid = GetCurrentPyInterpreterID(); EXPECT_NE(m_registrations.find(interpid), m_registrations.end(), @@ -71,7 +71,7 @@ PyTreeTypeRegistry::GetRegistrationsForCurrentPyInterpreterLocked() const { void PyTreeTypeRegistry::Init() { const scoped_write_lock lock{sm_mutex}; - const auto interpid = GetPyInterpreterID(); + const auto interpid = GetCurrentPyInterpreterID(); EXPECT_EQ(m_registrations.find(interpid), m_registrations.end(), @@ -377,7 +377,7 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( /*static*/ void PyTreeTypeRegistry::Clear() { const scoped_write_lock lock{sm_mutex}; - const auto interpid = GetPyInterpreterID(); + const auto interpid = GetCurrentPyInterpreterID(); auto ®istry1 = GetSingleton(); auto ®istry2 = GetSingleton(); diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index f227cc74..7e9aa5a8 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -104,16 +104,43 @@ def check_module_importable(): 'g': collections.defaultdict(list, h=collections.deque([7, 8, 9], maxlen=10)), } - flat, spec = optree._C.flatten(tree) - reconstructed = spec.unflatten(flat) - if reconstructed != tree: + leaves1, treespec1 = optree.tree_flatten(tree, none_is_leaf=False) + reconstructed1 = optree.tree_unflatten(treespec1, leaves1) + if reconstructed1 != tree: raise RuntimeError('unflatten/flatten mismatch') - if spec.num_leaves != 9: - raise RuntimeError(f'num_leaves mismatch: ({flat}, {spec})') - if flat != [1, 2, 3, 4, 5, 6, 7, 8, 9]: - raise RuntimeError(f'flattened leaves mismatch: ({flat}, {spec})') + if treespec1.num_leaves != len(leaves1): + raise RuntimeError(f'num_leaves mismatch: ({leaves1}, {treespec1})') + if leaves1 != [1, 2, 3, 4, 5, 6, 7, 8, 9]: + raise RuntimeError(f'flattened leaves mismatch: ({leaves1}, {treespec1})') + + leaves2, treespec2 = optree.tree_flatten(tree, none_is_leaf=True) + reconstructed2 = optree.tree_unflatten(treespec2, leaves2) + if reconstructed2 != tree: + raise RuntimeError('unflatten/flatten mismatch') + if treespec2.num_leaves != len(leaves2): + raise RuntimeError(f'num_leaves mismatch: ({leaves2}, {treespec2})') + if leaves2 != [ + 1, + 2, + 3, + 4, + None, + 5, + 6, + *([None] * (time.struct_time.n_sequence_fields - 1)), + 7, + 8, + 9, + ]: + raise RuntimeError(f'flattened leaves mismatch: ({leaves2}, {treespec2})') + + _ = optree.tree_flatten_with_path(tree, none_is_leaf=False) + _ = optree.tree_flatten_with_path(tree, none_is_leaf=True) + # _ = optree.tree_flatten_with_accessor(tree, none_is_leaf=False) + # _ = optree.tree_flatten_with_accessor(tree, none_is_leaf=True) 
return ( + optree._C.get_main_interpreter_id(), id(type(None)), id(tuple), id(list), @@ -126,6 +153,7 @@ def test_import(): import collections expected = ( + 0, id(type(None)), id(tuple), id(list), From 8c36df25d510c6d45746a2e8e751da2b74026109 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 10 Dec 2025 12:13:43 +0800 Subject: [PATCH 25/59] chore: reorgainize code --- src/optree.cpp | 2 +- tests/concurrent/test_subinterpreters.py | 30 ++++++++++-------------- tests/helpers.py | 27 ++++++++++----------- tests/test_ops.py | 2 +- tests/test_treespec.py | 12 ++-------- 5 files changed, 29 insertions(+), 44 deletions(-) diff --git a/src/optree.cpp b/src/optree.cpp index 9e9b2e5c..976f96d0 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -564,7 +564,7 @@ PYBIND11_MODULE(_C, mod, py::mod_gil_not_used()) #else PYBIND11_MODULE(_C, mod) #endif +// NOLINTEND[cppcoreguidelines-pro-bounds-pointer-arithmetic,cppcoreguidelines-pro-type-vararg] { optree::BuildModule(mod); } -// NOLINTEND[cppcoreguidelines-pro-bounds-pointer-arithmetic,cppcoreguidelines-pro-type-vararg] diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index 7e9aa5a8..c62ceaf1 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -192,9 +192,8 @@ def test_import(): def test_import_in_subinterpreter_after_main(): - check_script_in_subprocess( - textwrap.dedent( - """ + script = textwrap.dedent( + """ import contextlib import gc from concurrent import interpreters @@ -209,15 +208,13 @@ def test_import_in_subinterpreter_after_main(): for _ in range(10): gc.collect() """, - ).strip(), - rerun=NUM_FLAKY_RERUNS, - ) + ).strip() + check_script_in_subprocess(script, output='', rerun=NUM_FLAKY_RERUNS) def test_import_in_subinterpreter_before_main(): - check_script_in_subprocess( - textwrap.dedent( - """ + script = textwrap.dedent( + """ import contextlib import gc from concurrent import interpreters @@ -231,15 +228,13 @@ def test_import_in_subinterpreter_before_main(): for _ in range(10): gc.collect() """, - ).strip(), - rerun=NUM_FLAKY_RERUNS, - ) + ).strip() + check_script_in_subprocess(script, output='', rerun=NUM_FLAKY_RERUNS) def test_import_in_subinterpreters_concurrently(): - check_script_in_subprocess( - textwrap.dedent( - f""" + script = textwrap.dedent( + f""" from concurrent.futures import InterpreterPoolExecutor, as_completed def check_import(): @@ -253,6 +248,5 @@ def check_import(): for future in as_completed(futures): future.result() """, - ).strip(), - rerun=NUM_FLAKY_RERUNS, - ) + ).strip() + check_script_in_subprocess(script, output='', rerun=NUM_FLAKY_RERUNS) diff --git a/tests/helpers.py b/tests/helpers.py index baa47882..80f68108 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -156,24 +156,23 @@ def wrapper(*args, **kwargs): return wrapper -def check_script_in_subprocess(script, /, *, rerun=1): +def check_script_in_subprocess(script, /, *, output, env=None, cwd=TEST_ROOT, rerun=1): + if env is None: + env = os.environ env = { - key: value - for key, value in os.environ.items() - if not key.startswith(('PYTHON', 'PYTEST', 'COV_')) + key: value for key, value in env.items() if not key.startswith(('PYTHON', 'PYTEST', 'COV_')) } for _ in range(rerun): - assert ( - subprocess.check_output( - [sys.executable, '-Walways', '-Werror', '-c', script], - stderr=subprocess.STDOUT, - text=True, - encoding='utf-8', - cwd=TEST_ROOT, - env=env, - ) - == '' + result = subprocess.check_output( + [sys.executable, '-Walways', '-Werror', 
'-c', script], + stderr=subprocess.STDOUT, + text=True, + encoding='utf-8', + cwd=cwd, + env=env, ) + if output is not None: + assert result == output MISSING = object() diff --git a/tests/test_ops.py b/tests/test_ops.py index 02752bc3..843ca9ae 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -58,7 +58,7 @@ @skipif_android @skipif_ios def test_import_no_warnings(): - check_script_in_subprocess('import optree') + check_script_in_subprocess('import optree', output='') def test_max_depth(): diff --git a/tests/test_treespec.py b/tests/test_treespec.py index d07f3f3f..900572e8 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -43,6 +43,7 @@ MyAnotherDict, MyDict, Py_DEBUG, + check_script_in_subprocess, disable_systrace, gc_collect, parametrize, @@ -70,11 +71,6 @@ def test_treespec_construct(): with pytest.raises(TypeError, match=re.escape('No constructor defined!')): treespec.__init__() - env = { - key: value - for key, value in os.environ.items() - if not key.startswith(('PYTHON', 'PYTEST', 'COV_')) - } script = textwrap.dedent( r""" import signal @@ -95,11 +91,7 @@ def test_treespec_construct(): returncode = 0 try: with tempfile.TemporaryDirectory() as tmpdir: - subprocess.check_call( - [sys.executable, '-Walways', '-Werror', '-c', script], - cwd=tmpdir, - env=env, - ) + check_script_in_subprocess(script, cwd=tmpdir, output=None) except subprocess.CalledProcessError as ex: returncode = abs(ex.returncode) if 128 < returncode < 256: From b7d917399fd475afb2849d471ecabed7323ba5e6 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 10 Dec 2025 12:58:25 +0800 Subject: [PATCH 26/59] fix: fix docs dependency resolving --- .github/workflows/lint.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index be91bf1a..26a4f5dc 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -29,8 +29,11 @@ env: CLICOLOR_FORCE: "1" XDG_CACHE_HOME: "${{ github.workspace }}/.cache" PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" + UV_CACHE_DIR: "${{ github.workspace }}/.cache/pip/.uv" PRE_COMMIT_HOME: "${{ github.workspace }}/.cache/pip/.pre-commit" PIP_EXTRA_INDEX_URL: "https://download.pytorch.org/whl/cpu" + UV_INDEX: "https://download.pytorch.org/whl/cpu" + UV_INDEX_STRATEGY: "unsafe-best-match" jobs: lint: @@ -56,10 +59,10 @@ jobs: .pre-commit-config.yaml - name: Upgrade pip - run: python -m pip install --upgrade pip setuptools + run: python -m pip install --upgrade pip setuptools uv - name: Install dependencies - run: python -m pip install wheel pybind11 -r docs/requirements.txt + run: uv pip install --python=python wheel pybind11 -r docs/requirements.txt - name: Install nightly pybind11 shell: bash From d25f6dcd8442854c8962e43e87b95cd368b99882 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 10 Dec 2025 13:45:46 +0800 Subject: [PATCH 27/59] fix(workflows/lint): fix docs dependency resolving --- .github/workflows/lint.yml | 7 +++++-- include/optree/pymacros.h | 9 ++++++--- include/optree/registry.h | 18 ++++++++++++------ include/optree/synchronization.h | 2 +- include/optree/treespec.h | 8 ++++---- src/optree.cpp | 18 ++++++++---------- tests/helpers.py | 23 +++++++++++++++++++++++ tests/test_ops.py | 21 ++------------------- tests/test_treespec.py | 12 ++---------- 9 files changed, 63 insertions(+), 55 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index be91bf1a..7abf64eb 100644 --- a/.github/workflows/lint.yml +++ 
b/.github/workflows/lint.yml @@ -29,8 +29,11 @@ env: CLICOLOR_FORCE: "1" XDG_CACHE_HOME: "${{ github.workspace }}/.cache" PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" + UV_CACHE_DIR: "${{ github.workspace }}/.cache/pip/.uv" PRE_COMMIT_HOME: "${{ github.workspace }}/.cache/pip/.pre-commit" PIP_EXTRA_INDEX_URL: "https://download.pytorch.org/whl/cpu" + UV_INDEX: "https://download.pytorch.org/whl/cpu" + UV_INDEX_STRATEGY: "unsafe-best-match" jobs: lint: @@ -56,10 +59,10 @@ jobs: .pre-commit-config.yaml - name: Upgrade pip - run: python -m pip install --upgrade pip setuptools + run: python -m pip install --upgrade pip setuptools uv - name: Install dependencies - run: python -m pip install wheel pybind11 -r docs/requirements.txt + run: uv pip install --system --python=python wheel pybind11 -r docs/requirements.txt - name: Install nightly pybind11 shell: bash diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h index 490a6f2f..3e417d36 100644 --- a/include/optree/pymacros.h +++ b/include/optree/pymacros.h @@ -29,6 +29,9 @@ limitations under the License. # error "pybind11 2.12.0 or newer is required." #endif +// NOLINTNEXTLINE[bugprone-macro-parentheses] +#define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0)) + namespace py = pybind11; #if !defined(Py_ALWAYS_INLINE) @@ -83,6 +86,9 @@ Py_Declare_ID(__qualname__); // type.__qualname__ Py_Declare_ID(__name__); // type.__name__ Py_Declare_ID(sort); // list.sort Py_Declare_ID(copy); // dict.copy +Py_Declare_ID(OrderedDict); // OrderedDict +Py_Declare_ID(defaultdict); // defaultdict +Py_Declare_ID(deque); // deque Py_Declare_ID(default_factory); // defaultdict.default_factory Py_Declare_ID(maxlen); // deque.maxlen Py_Declare_ID(_fields); // namedtuple._fields @@ -91,6 +97,3 @@ Py_Declare_ID(_asdict); // namedtuple._asdict Py_Declare_ID(n_fields); // structseq.n_fields Py_Declare_ID(n_sequence_fields); // structseq.n_sequence_fields Py_Declare_ID(n_unnamed_fields); // structseq.n_unnamed_fields - -// NOLINTNEXTLINE[bugprone-macro-parentheses] -#define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0)) diff --git a/include/optree/registry.h b/include/optree/registry.h index d96d2560..bf1afb6d 100644 --- a/include/optree/registry.h +++ b/include/optree/registry.h @@ -98,7 +98,7 @@ class PyTreeTypeRegistry { using RegistrationPtr = std::shared_ptr; - // Register a new custom type. Objects of `cls` will be treated as container node types in + // Register a new custom type. Objects of type `cls` will be treated as container node types in // PyTrees. static void Register(const py::object &cls, const py::function &flatten_func, @@ -106,9 +106,10 @@ class PyTreeTypeRegistry { const py::object &path_entry_type, const std::string ®istry_namespace = ""); + // Unregister a previously registered custom type. static void Unregister(const py::object &cls, const std::string ®istry_namespace = ""); - // Find the custom type registration for `type`. Returns nullptr if none exists. + // Find the custom type registration for `type`. Return nullptr if none exists. template [[nodiscard]] static RegistrationPtr Lookup(const py::object &cls, const std::string ®istry_namespace); @@ -136,13 +137,18 @@ class PyTreeTypeRegistry { [[nodiscard]] static RegistrationPtr UnregisterImpl(const py::object &cls, const std::string ®istry_namespace); - // Clear the registry on cleanup. + // Clear the registry on cleanup for the current interpreter. 
static void Clear(); - std::unordered_map m_registrations{}; - std::unordered_map, RegistrationPtr> m_named_registrations{}; + using RegistrationsMap = std::unordered_map; + using NamedRegistrationsMap = + std::unordered_map, RegistrationPtr>; + using BuiltinsTypesSet = std::unordered_set; - static inline std::unordered_set sm_builtins_types{}; + RegistrationsMap m_registrations{}; + NamedRegistrationsMap m_named_registrations{}; + + static inline BuiltinsTypesSet sm_builtins_types{}; static inline read_write_mutex sm_mutex{}; }; diff --git a/include/optree/synchronization.h b/include/optree/synchronization.h index 461e82ee..83203f48 100644 --- a/include/optree/synchronization.h +++ b/include/optree/synchronization.h @@ -60,7 +60,7 @@ using scoped_lock = std::scoped_lock; using scoped_recursive_lock = std::scoped_lock; #if (defined(__APPLE__) /* header is not available on macOS build target */ && \ - PY_VERSION_HEX < /* Python 3.12.0 */ 0x030C00F0) + PY_VERSION_HEX < 0x030C00F0 /* Python 3.12.0 */) # undef HAVE_READ_WRITE_LOCK diff --git a/include/optree/treespec.h b/include/optree/treespec.h index fff08b8d..1ff25883 100644 --- a/include/optree/treespec.h +++ b/include/optree/treespec.h @@ -288,7 +288,7 @@ class PyTreeSpec { private: using RegistrationPtr = PyTreeTypeRegistry::RegistrationPtr; - using ThreadedIdentity = std::pair; + using ThreadedIdentity = std::pair; struct Node { PyTreeKind kind = PyTreeKind::Leaf; @@ -333,16 +333,16 @@ class PyTreeSpec { // The registry namespace used to resolve the custom pytree node types. std::string m_namespace{}; - // Helper that returns the string representation of a node kind. + // Return the string representation of a node kind. [[nodiscard]] static std::string NodeKindToString(const Node &node); - // Helper that manufactures an instance of a node given its children. + // Manufacture an instance of a node given its children. [[nodiscard]] static py::object MakeNode( const Node &node, const py::object children[], // NOLINT[hicpp-avoid-c-arrays] const size_t &num_children); - // Helper that identifies the path entry class for a node. + // Identify the path entry class for a node. [[nodiscard]] static py::object GetPathEntryType(const Node &node); // Recursive helper used to implement Flatten(). 
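As context for the Register/Unregister comment updates above: the C++ registry is what backs the public `optree.register_pytree_node()` API. A minimal sketch of registering a custom container node from Python — the `Point2D` class and the 'demo' namespace are illustrative only, not part of this patch:

import optree

class Point2D:
    def __init__(self, x, y):
        self.x, self.y = x, y

# Register Point2D as a pytree container node in a private namespace.
optree.register_pytree_node(
    Point2D,
    flatten_func=lambda p: ((p.x, p.y), None),           # (children, metadata)
    unflatten_func=lambda _, children: Point2D(*children),
    namespace='demo',
)

leaves, treespec = optree.tree_flatten(Point2D(1, 2), namespace='demo')
assert leaves == [1, 2]
assert optree.tree_unflatten(treespec, leaves).x == 1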
diff --git a/src/optree.cpp b/src/optree.cpp index 1cfbc9bc..15e72d88 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -52,9 +52,6 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] std::string(__FILE_RELPATH_FROM_PROJECT_ROOT__) + ")"; mod.attr("Py_TPFLAGS_BASETYPE") = py::int_(Py_TPFLAGS_BASETYPE); - // NOLINTNEXTLINE[bugprone-macro-parentheses] -#define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0)) - // Meta information during build py::dict BUILDTIME_METADATA{}; BUILDTIME_METADATA["PY_VERSION"] = py::str(PY_VERSION); @@ -99,8 +96,6 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] BUILDTIME_METADATA["GLIBCXX_USE_CXX11_ABI"] = py::bool_(false); #endif -#undef NONZERO_OR_EMPTY - mod.attr("BUILDTIME_METADATA") = std::move(BUILDTIME_METADATA); py::exec( R"py( @@ -156,7 +151,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] py::arg("namespace") = "") .def("flatten", &PyTreeSpec::Flatten, - "Flattens a pytree.", + "Flatten a pytree.", py::arg("tree"), py::pos_only(), py::arg("leaf_predicate") = std::nullopt, @@ -528,10 +523,13 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] } // namespace optree +// NOLINTBEGIN[cppcoreguidelines-pro-bounds-pointer-arithmetic,cppcoreguidelines-pro-type-vararg] #if PYBIND11_VERSION_HEX >= 0x020D00F0 // pybind11 2.13.0 -// NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic,cppcoreguidelines-pro-type-vararg] -PYBIND11_MODULE(_C, mod, py::mod_gil_not_used()) { optree::BuildModule(mod); } +PYBIND11_MODULE(_C, mod, py::mod_gil_not_used()) #else -// NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic,cppcoreguidelines-pro-type-vararg] -PYBIND11_MODULE(_C, mod) { optree::BuildModule(mod); } +PYBIND11_MODULE(_C, mod) #endif +// NOLINTEND[cppcoreguidelines-pro-bounds-pointer-arithmetic,cppcoreguidelines-pro-type-vararg] +{ + optree::BuildModule(mod); +} diff --git a/tests/helpers.py b/tests/helpers.py index 2da6a9b8..a45cc9cf 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -20,7 +20,9 @@ import functools import gc import itertools +import os import platform +import subprocess import sys import sysconfig import time @@ -148,6 +150,27 @@ def wrapper(*args, **kwargs): return wrapper +def check_script_in_subprocess(script, /, *, output, env=None, cwd=TEST_ROOT, rerun=1): + if env is None: + env = os.environ + env = { + key: value for key, value in env.items() if not key.startswith(('PYTHON', 'PYTEST', 'COV_')) + } + result = '' + for _ in range(rerun): + result = subprocess.check_output( + [sys.executable, '-Walways', '-Werror', '-c', script], + stderr=subprocess.STDOUT, + text=True, + encoding='utf-8', + cwd=cwd, + env=env, + ) + if output is not None: + assert result == output + return result + + MISSING = object() diff --git a/tests/test_ops.py b/tests/test_ops.py index 6ceebe4d..843ca9ae 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -19,11 +19,9 @@ import functools import itertools import operator -import os import pickle import platform import re -import subprocess import sys from collections import OrderedDict, defaultdict, deque @@ -34,7 +32,6 @@ GLOBAL_NAMESPACE, IS_LEAF_FUNCTIONS, LEAVES, - TEST_ROOT, TREE_ACCESSORS, TREE_PATHS, TREES, @@ -45,6 +42,7 @@ Py_DEBUG, always, assert_equal_type_and_value, + check_script_in_subprocess, is_list, is_none, is_tuple, @@ -60,22 +58,7 @@ @skipif_android @skipif_ios def test_import_no_warnings(): - env = { - key: value - for key, value in os.environ.items() - if not 
key.startswith(('PYTHON', 'PYTEST', 'COV_')) - } - assert ( - subprocess.check_output( - [sys.executable, '-Walways', '-Werror', '-c', 'import optree'], - stderr=subprocess.STDOUT, - text=True, - encoding='utf-8', - cwd=TEST_ROOT, - env=env, - ) - == '' - ) + check_script_in_subprocess('import optree', output='') def test_max_depth(): diff --git a/tests/test_treespec.py b/tests/test_treespec.py index d07f3f3f..900572e8 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -43,6 +43,7 @@ MyAnotherDict, MyDict, Py_DEBUG, + check_script_in_subprocess, disable_systrace, gc_collect, parametrize, @@ -70,11 +71,6 @@ def test_treespec_construct(): with pytest.raises(TypeError, match=re.escape('No constructor defined!')): treespec.__init__() - env = { - key: value - for key, value in os.environ.items() - if not key.startswith(('PYTHON', 'PYTEST', 'COV_')) - } script = textwrap.dedent( r""" import signal @@ -95,11 +91,7 @@ def test_treespec_construct(): returncode = 0 try: with tempfile.TemporaryDirectory() as tmpdir: - subprocess.check_call( - [sys.executable, '-Walways', '-Werror', '-c', script], - cwd=tmpdir, - env=env, - ) + check_script_in_subprocess(script, cwd=tmpdir, output=None) except subprocess.CalledProcessError as ex: returncode = abs(ex.returncode) if 128 < returncode < 256: From 0801449670ef48a013d5c99fbd7611c03627aec7 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 10 Dec 2025 15:08:33 +0800 Subject: [PATCH 28/59] fix(workflows/lint): fix docs dependency resolving --- .github/workflows/lint.yml | 7 ++----- docs/requirements.txt | 2 +- pyproject.toml | 2 +- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 7abf64eb..be91bf1a 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -29,11 +29,8 @@ env: CLICOLOR_FORCE: "1" XDG_CACHE_HOME: "${{ github.workspace }}/.cache" PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" - UV_CACHE_DIR: "${{ github.workspace }}/.cache/pip/.uv" PRE_COMMIT_HOME: "${{ github.workspace }}/.cache/pip/.pre-commit" PIP_EXTRA_INDEX_URL: "https://download.pytorch.org/whl/cpu" - UV_INDEX: "https://download.pytorch.org/whl/cpu" - UV_INDEX_STRATEGY: "unsafe-best-match" jobs: lint: @@ -59,10 +56,10 @@ jobs: .pre-commit-config.yaml - name: Upgrade pip - run: python -m pip install --upgrade pip setuptools uv + run: python -m pip install --upgrade pip setuptools - name: Install dependencies - run: uv pip install --system --python=python wheel pybind11 -r docs/requirements.txt + run: python -m pip install wheel pybind11 -r docs/requirements.txt - name: Install nightly pybind11 shell: bash diff --git a/docs/requirements.txt b/docs/requirements.txt index 85a0ae49..87c90e12 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,6 @@ --requirement ../requirements.txt -sphinx +sphinx ~= 8.0 sphinx-autoapi sphinx-autobuild sphinx-autodoc-typehints diff --git a/pyproject.toml b/pyproject.toml index 86c0838e..f4bc24fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,7 @@ test = [ "typing-extensions == 4.12.0; python_version >= '3.13' and platform_system == 'Windows'", ] docs = [ - "sphinx", + "sphinx ~= 8.0", "sphinx-autoapi", "sphinx-autobuild", "sphinx-autodoc-typehints", From e119a2f2daba319af8a87e10cc0812201cf55983 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sun, 14 Dec 2025 16:28:04 +0800 Subject: [PATCH 29/59] chore: update nightly pybind11 url --- .github/workflows/set_setup_requires.py | 3 +- 
tests/concurrent/test_subinterpreters.py | 68 ++++++++++++------------ 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/.github/workflows/set_setup_requires.py b/.github/workflows/set_setup_requires.py index 5bbe74c0..0e442ea1 100755 --- a/.github/workflows/set_setup_requires.py +++ b/.github/workflows/set_setup_requires.py @@ -9,6 +9,7 @@ ROOT = Path(__file__).absolute().parents[2] PYPROJECT_FILE = ROOT / 'pyproject.toml' +PYBIND11_GIT_URL = 'https://github.com/XuehaiPan/pybind11.git@subinterp-call-once-and-store' if __name__ == '__main__': @@ -17,7 +18,7 @@ PYPROJECT_FILE.write_text( data=re.sub( r'(requires\s*=\s*\[.*"\s*)\bpybind11\b[^"]*(\s*".*\])', - r'\g<1>pybind11 @ git+https://github.com/pybind/pybind11.git#egg=pybind11\g<2>', + rf'\g<1>pybind11 @ git+{PYBIND11_GIT_URL}#egg=pybind11\g<2>', string=PYPROJECT_CONTENT, ), encoding='utf-8', diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index c62ceaf1..5ffc5088 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -136,8 +136,8 @@ def check_module_importable(): _ = optree.tree_flatten_with_path(tree, none_is_leaf=False) _ = optree.tree_flatten_with_path(tree, none_is_leaf=True) - # _ = optree.tree_flatten_with_accessor(tree, none_is_leaf=False) - # _ = optree.tree_flatten_with_accessor(tree, none_is_leaf=True) + _ = optree.tree_flatten_with_accessor(tree, none_is_leaf=False) + _ = optree.tree_flatten_with_accessor(tree, none_is_leaf=True) return ( optree._C.get_main_interpreter_id(), @@ -194,20 +194,20 @@ def test_import(): def test_import_in_subinterpreter_after_main(): script = textwrap.dedent( """ - import contextlib - import gc - from concurrent import interpreters + import contextlib + import gc + from concurrent import interpreters - import optree + import optree - subinterpreter = None - with contextlib.closing(interpreters.create()) as subinterpreter: - subinterpreter.exec('import optree') + subinterpreter = None + with contextlib.closing(interpreters.create()) as subinterpreter: + subinterpreter.exec('import optree') - del optree, subinterpreter - for _ in range(10): - gc.collect() - """, + del optree, subinterpreter + for _ in range(10): + gc.collect() + """, ).strip() check_script_in_subprocess(script, output='', rerun=NUM_FLAKY_RERUNS) @@ -215,19 +215,19 @@ def test_import_in_subinterpreter_after_main(): def test_import_in_subinterpreter_before_main(): script = textwrap.dedent( """ - import contextlib - import gc - from concurrent import interpreters + import contextlib + import gc + from concurrent import interpreters - subinterpreter = None - with contextlib.closing(interpreters.create()) as subinterpreter: - subinterpreter.exec('import optree') + subinterpreter = None + with contextlib.closing(interpreters.create()) as subinterpreter: + subinterpreter.exec('import optree') - import optree - del optree, subinterpreter - for _ in range(10): - gc.collect() - """, + import optree + del optree, subinterpreter + for _ in range(10): + gc.collect() + """, ).strip() check_script_in_subprocess(script, output='', rerun=NUM_FLAKY_RERUNS) @@ -235,18 +235,18 @@ def test_import_in_subinterpreter_before_main(): def test_import_in_subinterpreters_concurrently(): script = textwrap.dedent( f""" - from concurrent.futures import InterpreterPoolExecutor, as_completed + from concurrent.futures import InterpreterPoolExecutor, as_completed - def check_import(): - import optree + def check_import(): + import optree - if 
optree._C.get_registry_size() != 8: - raise RuntimeError('registry size mismatch') + if optree._C.get_registry_size() != 8: + raise RuntimeError('registry size mismatch') - with InterpreterPoolExecutor(max_workers={NUM_WORKERS}) as executor: - futures = [executor.submit(check_import) for _ in range({NUM_FUTURES})] - for future in as_completed(futures): - future.result() - """, + with InterpreterPoolExecutor(max_workers={NUM_WORKERS}) as executor: + futures = [executor.submit(check_import) for _ in range({NUM_FUTURES})] + for future in as_completed(futures): + future.result() + """, ).strip() check_script_in_subprocess(script, output='', rerun=NUM_FLAKY_RERUNS) From 5dcc2f88cb92175d735fa44602eb80565d81dc49 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sun, 14 Dec 2025 20:58:20 +0800 Subject: [PATCH 30/59] feat: improve sanity check error messages --- include/optree/exceptions.h | 73 +++++++++++++++++++++++++++++----- include/optree/registry.h | 12 +++--- src/registry.cpp | 3 +- src/treespec/serialization.cpp | 2 +- src/treespec/traversal.cpp | 2 +- src/treespec/treespec.cpp | 2 +- src/treespec/unflatten.cpp | 2 +- 7 files changed, 74 insertions(+), 22 deletions(-) diff --git a/include/optree/exceptions.h b/include/optree/exceptions.h index 9b2906cc..469ea0bb 100644 --- a/include/optree/exceptions.h +++ b/include/optree/exceptions.h @@ -17,11 +17,12 @@ limitations under the License. #pragma once -#include // std::size_t -#include // std::optional, std::nullopt -#include // std::ostringstream -#include // std::logic_error -#include // std::string, std::char_traits +#include // std::size_t +#include // std::optional, std::nullopt +#include // std::ostringstream +#include // std::logic_error +#include // std::string, std::char_traits, std::to_string +#include // std::declval, std::void_t, std::{true,false}_type namespace optree { @@ -56,6 +57,26 @@ class InternalError : public std::logic_error { }()) {} }; +// SFINAE helper to detect if std::to_string is available for a type +template +struct has_to_string : std::false_type {}; + +template +struct has_to_string()))>> : std::true_type { +}; + +template +inline constexpr bool has_to_string_v = has_to_string::value; + +// Convert value to string if possible, otherwise return a placeholder. +template +inline std::string to_string([[maybe_unused]] const T &value) { + if constexpr (has_to_string_v) { + return std::to_string(value); + } + return ""; +} + #define VA_FUNC2_(__0, __1, NAME, ...) NAME #define VA_FUNC3_(__0, __1, __2, NAME, ...) NAME @@ -81,13 +102,43 @@ class InternalError : public std::logic_error { #define EXPECT_(...) \ VA_FUNC3_(__0 __VA_OPT__(, ) __VA_ARGS__, EXPECT2_, EXPECT1_, EXPECT0_)(__VA_ARGS__) +#define EXPECT_OP0_(a, b, op) \ + { \ + const auto &__a = (a); \ + const auto &__b = (b); \ + EXPECT2_((__a)op(__b), \ + "Expected `(" #a ") " #op " (" #b ")`, but got `! 
(" + optree::to_string(__a) + \ + ") " #op " (" + optree::to_string(__b) + ")`."); \ + } +#define EXPECT_OP1_(a, b, op, nop) \ + { \ + const auto &__a = (a); \ + const auto &__b = (b); \ + EXPECT2_((__a)op(__b), \ + "Expected `(" #a ") " #op " (" #b ")`, but got `(" + optree::to_string(__a) + \ + ") " #nop " (" + optree::to_string(__b) + ")`."); \ + } +#define EXPECT_OP2_(a, b, op, nop, message) \ + { \ + const auto &__a = (a); \ + const auto &__b = (b); \ + EXPECT2_((__a)op(__b), \ + std::string(message) + " (Expected `(" #a ") " #op " (" #b ")`, but got `(" + \ + optree::to_string(__a) + ") " #nop " (" + optree::to_string(__b) + ")`.)"); \ + } +#define EXPECT_OP_(a, b, op, ...) \ + VA_FUNC3_(__0 __VA_OPT__(, ) __VA_ARGS__, \ + EXPECT_OP2_, \ + EXPECT_OP1_, \ + EXPECT_OP0_)(a, b, op __VA_OPT__(, ) __VA_ARGS__) + #define EXPECT_TRUE(condition, ...) EXPECT_((condition)__VA_OPT__(, ) __VA_ARGS__) #define EXPECT_FALSE(condition, ...) EXPECT_(!(condition)__VA_OPT__(, ) __VA_ARGS__) -#define EXPECT_EQ(a, b, ...) EXPECT_((a) == (b)__VA_OPT__(, ) __VA_ARGS__) -#define EXPECT_NE(a, b, ...) EXPECT_((a) != (b)__VA_OPT__(, ) __VA_ARGS__) -#define EXPECT_LT(a, b, ...) EXPECT_((a) < (b)__VA_OPT__(, ) __VA_ARGS__) -#define EXPECT_LE(a, b, ...) EXPECT_((a) <= (b)__VA_OPT__(, ) __VA_ARGS__) -#define EXPECT_GT(a, b, ...) EXPECT_((a) > (b)__VA_OPT__(, ) __VA_ARGS__) -#define EXPECT_GE(a, b, ...) EXPECT_((a) >= (b)__VA_OPT__(, ) __VA_ARGS__) +#define EXPECT_EQ(a, b, ...) EXPECT_OP_(a, b, ==, != __VA_OPT__(, ) __VA_ARGS__) +#define EXPECT_NE(a, b, ...) EXPECT_OP_(a, b, !=, == __VA_OPT__(, ) __VA_ARGS__) +#define EXPECT_LT(a, b, ...) EXPECT_OP_(a, b, <, >= __VA_OPT__(, ) __VA_ARGS__) +#define EXPECT_LE(a, b, ...) EXPECT_OP_(a, b, <=, > __VA_OPT__(, ) __VA_ARGS__) +#define EXPECT_GT(a, b, ...) EXPECT_OP_(a, b, >, <= __VA_OPT__(, ) __VA_ARGS__) +#define EXPECT_GE(a, b, ...) EXPECT_OP_(a, b, >=, < __VA_OPT__(, ) __VA_ARGS__) } // namespace optree diff --git a/include/optree/registry.h b/include/optree/registry.h index f66218df..381f7bc9 100644 --- a/include/optree/registry.h +++ b/include/optree/registry.h @@ -141,12 +141,6 @@ class PyTreeTypeRegistry { // Get the number of alive interpreters that have seen the registry. [[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersAlive() { - const scoped_read_lock lock{sm_mutex}; - return sm_num_interpreters_seen; - } - - // Get the number of interpreters that have seen the registry. - [[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersSeen() { const scoped_read_lock lock{sm_mutex}; EXPECT_EQ(py::ssize_t_cast(sm_builtins_types.size()), sm_num_interpreters_alive, @@ -155,6 +149,12 @@ class PyTreeTypeRegistry { return sm_num_interpreters_alive; } + // Get the number of interpreters that have seen the registry. + [[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersSeen() { + const scoped_read_lock lock{sm_mutex}; + return sm_num_interpreters_seen; + } + // Get the IDs of alive interpreters that have seen the registry. 
[[nodiscard]] static inline Py_ALWAYS_INLINE std::unordered_set GetAliveInterpreterIDs() { diff --git a/src/registry.cpp b/src/registry.cpp index 2204bf8f..77e0d8fc 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -390,7 +390,8 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( EXPECT_LE(builtins_types.size(), registrations1.size()); EXPECT_EQ(registrations1.size(), registrations2.size() + 1); EXPECT_EQ(named_registrations1.size(), named_registrations2.size()); - EXPECT_EQ(&builtins_types, &builtins_types_); + EXPECT_EQ(reinterpret_cast(&builtins_types), + reinterpret_cast(&builtins_types)); #if defined(Py_DEBUG) for (const auto &cls : builtins_types) { diff --git a/src/treespec/serialization.cpp b/src/treespec/serialization.cpp index b80770c9..06556588 100644 --- a/src/treespec/serialization.cpp +++ b/src/treespec/serialization.cpp @@ -243,7 +243,7 @@ std::string PyTreeSpec::ToStringImpl() const { agenda.emplace_back(sstream.str()); } - EXPECT_EQ(agenda.size(), 1, "PyTreeSpec traversal did not yield a singleton."); + EXPECT_EQ(agenda.size(), 1U, "PyTreeSpec traversal did not yield a singleton."); std::ostringstream oss{}; oss << "PyTreeSpec(" << agenda.back(); if (m_none_is_leaf) [[unlikely]] { diff --git a/src/treespec/traversal.cpp b/src/treespec/traversal.cpp index 36917362..e22ba47e 100644 --- a/src/treespec/traversal.cpp +++ b/src/treespec/traversal.cpp @@ -248,7 +248,7 @@ py::object PyTreeSpec::WalkImpl(const py::iterable &leaves, throw py::value_error("Too many leaves for PyTreeSpec."); } - EXPECT_EQ(agenda.size(), 1, "PyTreeSpec traversal did not yield a singleton."); + EXPECT_EQ(agenda.size(), 1U, "PyTreeSpec traversal did not yield a singleton."); return agenda.back(); } diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index d952c82c..165ee665 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -551,7 +551,7 @@ std::unique_ptr PyTreeSpec::Transform(const std::optionalm_traversal.back(); diff --git a/src/treespec/unflatten.cpp b/src/treespec/unflatten.cpp index db43a83e..adc9e50e 100644 --- a/src/treespec/unflatten.cpp +++ b/src/treespec/unflatten.cpp @@ -75,7 +75,7 @@ py::object PyTreeSpec::UnflattenImpl(const Span &leaves) const { oss << "Too many leaves for PyTreeSpec; expected: " << GetNumLeaves() << "."; throw py::value_error(oss.str()); } - EXPECT_EQ(agenda.size(), 1, "PyTreeSpec traversal did not yield a singleton."); + EXPECT_EQ(agenda.size(), 1U, "PyTreeSpec traversal did not yield a singleton."); return agenda.back(); } From 1c43acc3d2abb86e14d868496f037ff5a7a9480f Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sun, 14 Dec 2025 22:29:41 +0800 Subject: [PATCH 31/59] revert --- include/optree/pymacros.h | 45 +++---- include/optree/pytypes.h | 18 +-- include/optree/registry.h | 32 +---- src/optree.cpp | 10 +- src/registry.cpp | 264 ++++++++++++++++---------------------- 5 files changed, 137 insertions(+), 232 deletions(-) diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h index 6759508f..05a2e25a 100644 --- a/include/optree/pymacros.h +++ b/include/optree/pymacros.h @@ -72,35 +72,22 @@ inline constexpr Py_ALWAYS_INLINE bool Py_IsConstant(PyObject *x) noexcept { } #define Py_IsConstant(x) Py_IsConstant(x) -#if defined(OPTREE_HAS_SUBINTERPRETER_SUPPORT) -# define Py_Declare_ID(name) \ - namespace { \ - [[nodiscard]] inline PyObject *Py_ID_##name() { \ - PyObject * const ptr = PyUnicode_InternFromString(#name); \ - if (ptr == nullptr) [[unlikely]] { \ - throw py::error_already_set(); \ - } \ - return ptr; \ - } 
\ - } // namespace -#else -# define Py_Declare_ID(name) \ - namespace { \ - [[nodiscard]] inline PyObject *Py_ID_##name() { \ - PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; \ - return storage \ - .call_once_and_store_result([]() -> PyObject * { \ - PyObject * const ptr = PyUnicode_InternFromString(#name); \ - if (ptr == nullptr) [[unlikely]] { \ - throw py::error_already_set(); \ - } \ - Py_INCREF(ptr); /* leak a reference on purpose */ \ - return ptr; \ - }) \ - .get_stored(); \ - } \ - } // namespace -#endif +#define Py_Declare_ID(name) \ + namespace { \ + [[nodiscard]] inline PyObject *Py_ID_##name() { \ + PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; \ + return storage \ + .call_once_and_store_result([]() -> PyObject * { \ + PyObject * const ptr = PyUnicode_InternFromString(#name); \ + if (ptr == nullptr) [[unlikely]] { \ + throw py::error_already_set(); \ + } \ + Py_INCREF(ptr); /* leak a reference on purpose */ \ + return ptr; \ + }) \ + .get_stored(); \ + } \ + } // namespace #define Py_Get_ID(name) (::Py_ID_##name()) diff --git a/include/optree/pytypes.h b/include/optree/pytypes.h index b3cfbabd..b031929b 100644 --- a/include/optree/pytypes.h +++ b/include/optree/pytypes.h @@ -71,35 +71,25 @@ inline const py::object &ImportOrderedDict() { PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; return storage .call_once_and_store_result([]() -> py::object { - return py::getattr(py::module_::import("collections"), Py_Get_ID(OrderedDict)); + return py::getattr(py::module_::import("collections"), "OrderedDict"); }) .get_stored(); } -#if defined(OPTREE_HAS_SUBINTERPRETER_SUPPORT) -inline py::object ImportDefaultDict() { - return py::getattr(py::module_::import("collections"), Py_Get_ID(defaultdict)); -} -inline py::object ImportDeque() { - return py::getattr(py::module_::import("collections"), Py_Get_ID(deque)); -} -#else inline const py::object &ImportDefaultDict() { PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; return storage .call_once_and_store_result([]() -> py::object { - return py::getattr(py::module_::import("collections"), Py_Get_ID(defaultdict)); + return py::getattr(py::module_::import("collections"), "defaultdict"); }) .get_stored(); } inline const py::object &ImportDeque() { PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; return storage - .call_once_and_store_result([]() -> py::object { - return py::getattr(py::module_::import("collections"), Py_Get_ID(deque)); - }) + .call_once_and_store_result( + []() -> py::object { return py::getattr(py::module_::import("collections"), "deque"); }) .get_stored(); } -#endif inline Py_ALWAYS_INLINE py::ssize_t TupleGetSize(const py::handle &tuple) { return PyTuple_GET_SIZE(tuple.ptr()); diff --git a/include/optree/registry.h b/include/optree/registry.h index 381f7bc9..ec5dbb39 100644 --- a/include/optree/registry.h +++ b/include/optree/registry.h @@ -142,11 +142,7 @@ class PyTreeTypeRegistry { // Get the number of alive interpreters that have seen the registry. [[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersAlive() { const scoped_read_lock lock{sm_mutex}; - EXPECT_EQ(py::ssize_t_cast(sm_builtins_types.size()), - sm_num_interpreters_alive, - "The number of alive interpreters should match the size of the " - "interpreter-scoped registered types map."); - return sm_num_interpreters_alive; + return py::ssize_t_cast(sm_alive_interpids.size()); } // Get the number of interpreters that have seen the registry. 
@@ -159,15 +155,7 @@ class PyTreeTypeRegistry { [[nodiscard]] static inline Py_ALWAYS_INLINE std::unordered_set GetAliveInterpreterIDs() { const scoped_read_lock lock{sm_mutex}; - EXPECT_EQ(py::ssize_t_cast(sm_builtins_types.size()), - sm_num_interpreters_alive, - "The number of alive interpreters should match the size of the " - "interpreter-scoped registered types map."); - std::unordered_set interpids; - for (const auto &[interpid, _] : sm_builtins_types) { - interpids.insert(interpid); - } - return interpids; + return sm_alive_interpids; } friend void BuildModule(py::module_ &mod); // NOLINT[runtime/references] @@ -188,7 +176,7 @@ class PyTreeTypeRegistry { const std::string ®istry_namespace); // Initialize the registry for the current interpreter. - void Init(); + static void Init(); // Clear the registry on cleanup for the current interpreter. static void Clear(); @@ -198,18 +186,12 @@ class PyTreeTypeRegistry { std::unordered_map, RegistrationPtr>; using BuiltinsTypesSet = std::unordered_set; - // Get the registrations for the current Python interpreter. - [[nodiscard]] inline Py_ALWAYS_INLINE std:: - tuple - GetRegistrationsForCurrentPyInterpreterLocked() const; + RegistrationsMap m_registrations{}; + NamedRegistrationsMap m_named_registrations{}; + BuiltinsTypesSet m_builtins_types{}; - bool m_none_is_leaf = false; - std::unordered_map m_registrations{}; - std::unordered_map m_named_registrations{}; - - static inline std::unordered_map sm_builtins_types{}; + static inline std::unordered_set sm_alive_interpids{}; static inline read_write_mutex sm_mutex{}; - static inline ssize_t sm_num_interpreters_alive = 0; static inline ssize_t sm_num_interpreters_seen = 0; }; diff --git a/src/optree.cpp b/src/optree.cpp index 976f96d0..09659ba8 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -541,15 +541,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] PyType_Modified(PyTreeSpec_Type); PyType_Modified(PyTreeIter_Type); - { - const scoped_write_lock interp_lock{PyTreeTypeRegistry::sm_mutex}; - ++PyTreeTypeRegistry::sm_num_interpreters_alive; - ++PyTreeTypeRegistry::sm_num_interpreters_seen; - } - PyTreeTypeRegistry::GetSingleton().Init(); - PyTreeTypeRegistry::GetSingleton().Init(); - py::getattr(py::module_::import("atexit"), - "register")(py::cpp_function(&PyTreeTypeRegistry::Clear)); + PyTreeTypeRegistry::Init(); } } // namespace optree diff --git a/src/registry.cpp b/src/registry.cpp index 77e0d8fc..fa8f401d 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -35,7 +35,34 @@ template return storage .call_once_and_store_result([]() -> PyTreeTypeRegistry { PyTreeTypeRegistry registry{}; - registry.m_none_is_leaf = NoneIsLeaf; + + const auto add_builtin_type = [®istry](const py::object &cls, + const PyTreeKind &kind) -> void { + EXPECT_TRUE(registry.m_builtins_types.emplace(cls).second, + "PyTree type " + PyRepr(cls) + + " is already registered in the built-in types set."); + cls.inc_ref(); + if (!NoneIsLeaf || kind != PyTreeKind::None) { + auto registration = + std::make_shared>(); + registration->kind = kind; + registration->type = py::reinterpret_borrow(cls); + EXPECT_TRUE( + registry.m_registrations.emplace(cls, std::move(registration)).second, + "PyTree type " + PyRepr(cls) + + " is already registered in the global namespace."); + if constexpr (!NoneIsLeaf) { + cls.inc_ref(); + } + } + }; + add_builtin_type(PyNoneTypeObject, PyTreeKind::None); + add_builtin_type(PyTupleTypeObject, PyTreeKind::Tuple); + add_builtin_type(PyListTypeObject, PyTreeKind::List); + 
add_builtin_type(PyDictTypeObject, PyTreeKind::Dict); + add_builtin_type(PyOrderedDictTypeObject, PyTreeKind::OrderedDict); + add_builtin_type(PyDefaultDictTypeObject, PyTreeKind::DefaultDict); + add_builtin_type(PyDequeTypeObject, PyTreeKind::Deque); return registry; }) .get_stored(); @@ -44,94 +71,11 @@ template template PyTreeTypeRegistry &PyTreeTypeRegistry::GetSingleton(); template PyTreeTypeRegistry &PyTreeTypeRegistry::GetSingleton(); -std::tuple -PyTreeTypeRegistry::GetRegistrationsForCurrentPyInterpreterLocked() const { - const auto interpid = GetCurrentPyInterpreterID(); - - EXPECT_NE(m_registrations.find(interpid), - m_registrations.end(), - "Interpreter ID " + std::to_string(interpid) + " not found in `m_registrations`."); - EXPECT_NE( - m_named_registrations.find(interpid), - m_named_registrations.end(), - "Interpreter ID " + std::to_string(interpid) + " not found in `m_named_registrations`."); - EXPECT_NE(sm_builtins_types.find(interpid), - sm_builtins_types.end(), - "Interpreter ID " + std::to_string(interpid) + " not found in `sm_builtins_types`."); - - // NOLINTBEGIN[cppcoreguidelines-pro-type-const-cast] - return {const_cast(m_registrations.at(interpid)), - const_cast(m_named_registrations.at(interpid)), - const_cast(sm_builtins_types.at(interpid))}; - // NOLINTEND[cppcoreguidelines-pro-type-const-cast] -} - -void PyTreeTypeRegistry::Init() { - const scoped_write_lock lock{sm_mutex}; - - const auto interpid = GetCurrentPyInterpreterID(); - - EXPECT_EQ(m_registrations.find(interpid), - m_registrations.end(), - "Interpreter ID " + std::to_string(interpid) + - " is already initialized in `m_registrations`."); - EXPECT_EQ(m_named_registrations.find(interpid), - m_named_registrations.end(), - "Interpreter ID " + std::to_string(interpid) + - " is already initialized in `m_named_registrations`."); - if (!m_none_is_leaf) [[likely]] { - EXPECT_EQ(sm_builtins_types.find(interpid), - sm_builtins_types.end(), - "Interpreter ID " + std::to_string(interpid) + - " is already initialized in `sm_builtins_types`."); - } else { - EXPECT_NE(sm_builtins_types.find(interpid), - sm_builtins_types.end(), - "Interpreter ID " + std::to_string(interpid) + - " is not initialized in `sm_builtins_types`."); - } - - auto ®istrations = m_registrations.try_emplace(interpid).first->second; - auto &named_registrations = m_named_registrations.try_emplace(interpid).first->second; - auto &builtins_types = sm_builtins_types.try_emplace(interpid).first->second; - - (void)named_registrations; // silence unused variable warning - - const auto add_builtin_type = - [®istrations, &builtins_types](const py::object &cls, const PyTreeKind &kind) -> void { - auto registration = std::make_shared>(); - registration->kind = kind; - registration->type = py::reinterpret_borrow(cls); - EXPECT_TRUE( - registrations.emplace(cls, std::move(registration)).second, - "PyTree type " + PyRepr(cls) + " is already registered in the global namespace."); - if (builtins_types.emplace(cls).second) [[likely]] { - cls.inc_ref(); - } - }; - if (!m_none_is_leaf) [[likely]] { - add_builtin_type(PyNoneTypeObject, PyTreeKind::None); - } - add_builtin_type(PyTupleTypeObject, PyTreeKind::Tuple); - add_builtin_type(PyListTypeObject, PyTreeKind::List); - add_builtin_type(PyDictTypeObject, PyTreeKind::Dict); - add_builtin_type(PyOrderedDictTypeObject, PyTreeKind::OrderedDict); - add_builtin_type(PyDefaultDictTypeObject, PyTreeKind::DefaultDict); - add_builtin_type(PyDequeTypeObject, PyTreeKind::Deque); -} - ssize_t PyTreeTypeRegistry::Size(const 
std::optional ®istry_namespace) const { const scoped_read_lock lock{sm_mutex}; - const auto &[registrations, named_registrations, builtins_types] = - GetRegistrationsForCurrentPyInterpreterLocked(); - - (void)builtins_types; // silence unused variable warning - - ssize_t count = py::ssize_t_cast(registrations.size()); - for (const auto &[named_type, _] : named_registrations) { + ssize_t count = py::ssize_t_cast(m_registrations.size()); + for (const auto &[named_type, _] : m_named_registrations) { if (!registry_namespace || named_type.first == *registry_namespace) [[likely]] { ++count; } @@ -145,12 +89,9 @@ template const py::function &unflatten_func, const py::object &path_entry_type, const std::string ®istry_namespace) { - const auto ®istry = GetSingleton(); - - auto [registrations, named_registrations, builtins_types] = - registry.GetRegistrationsForCurrentPyInterpreterLocked(); + auto ®istry = GetSingleton(); - if (builtins_types.find(cls) != builtins_types.end()) [[unlikely]] { + if (registry.m_builtins_types.find(cls) != registry.m_builtins_types.end()) [[unlikely]] { throw py::value_error("PyTree type " + PyRepr(cls) + " is a built-in type and cannot be re-registered."); } @@ -162,7 +103,7 @@ template registration->unflatten_func = py::reinterpret_borrow(unflatten_func); registration->path_entry_type = py::reinterpret_borrow(path_entry_type); if (registry_namespace.empty()) [[unlikely]] { - if (!registrations.emplace(cls, std::move(registration)).second) [[unlikely]] { + if (!registry.m_registrations.emplace(cls, std::move(registration)).second) [[unlikely]] { throw py::value_error("PyTree type " + PyRepr(cls) + " is already registered in the global namespace."); } @@ -184,7 +125,7 @@ template /*stack_level=*/2); } } else [[likely]] { - if (!named_registrations + if (!registry.m_named_registrations .emplace(std::make_pair(registry_namespace, cls), std::move(registration)) .second) [[unlikely]] { std::ostringstream oss{}; @@ -243,19 +184,16 @@ template /*static*/ PyTreeTypeRegistry::RegistrationPtr PyTreeTypeRegistry::UnregisterImpl( const py::object &cls, const std::string ®istry_namespace) { - const auto ®istry = GetSingleton(); - - auto [registrations, named_registrations, builtins_types] = - registry.GetRegistrationsForCurrentPyInterpreterLocked(); + auto ®istry = GetSingleton(); - if (builtins_types.find(cls) != builtins_types.end()) [[unlikely]] { + if (registry.m_builtins_types.find(cls) != registry.m_builtins_types.end()) [[unlikely]] { throw py::value_error("PyTree type " + PyRepr(cls) + " is a built-in type and cannot be unregistered."); } if (registry_namespace.empty()) [[unlikely]] { - const auto it = registrations.find(cls); - if (it == registrations.end()) [[unlikely]] { + const auto it = registry.m_registrations.find(cls); + if (it == registry.m_registrations.end()) [[unlikely]] { std::ostringstream oss{}; oss << "PyTree type " << PyRepr(cls) << " "; if (IsStructSequenceClass(cls)) [[unlikely]] { @@ -270,11 +208,12 @@ template throw py::value_error(oss.str()); } RegistrationPtr registration = it->second; - registrations.erase(it); + registry.m_registrations.erase(it); return registration; } else [[likely]] { - const auto named_it = named_registrations.find(std::make_pair(registry_namespace, cls)); - if (named_it == named_registrations.end()) [[unlikely]] { + const auto named_it = + registry.m_named_registrations.find(std::make_pair(registry_namespace, cls)); + if (named_it == registry.m_named_registrations.end()) [[unlikely]] { std::ostringstream oss{}; oss << "PyTree 
type " << PyRepr(cls) << " "; if (IsStructSequenceClass(cls)) [[unlikely]] { @@ -290,7 +229,7 @@ template throw py::value_error(oss.str()); } RegistrationPtr registration = named_it->second; - named_registrations.erase(named_it); + registry.m_named_registrations.erase(named_it); return registration; } } @@ -318,18 +257,15 @@ template const scoped_read_lock lock{sm_mutex}; const auto ®istry = GetSingleton(); - - const auto &[registrations, named_registrations, _] = - registry.GetRegistrationsForCurrentPyInterpreterLocked(); - if (!registry_namespace.empty()) [[unlikely]] { - const auto named_it = named_registrations.find(std::make_pair(registry_namespace, cls)); - if (named_it != named_registrations.end()) [[likely]] { + const auto named_it = + registry.m_named_registrations.find(std::make_pair(registry_namespace, cls)); + if (named_it != registry.m_named_registrations.end()) [[likely]] { return named_it->second; } } - const auto it = registrations.find(cls); - return it != registrations.end() ? it->second : nullptr; + const auto it = registry.m_registrations.find(cls); + return it != registry.m_registrations.end() ? it->second : nullptr; } template PyTreeTypeRegistry::RegistrationPtr PyTreeTypeRegistry::Lookup( @@ -373,33 +309,57 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( PyTreeTypeRegistry::RegistrationPtr &, // NOLINT[runtime/references] const std::string &); +/*static*/ void PyTreeTypeRegistry::Init() { + const scoped_write_lock lock{sm_mutex}; + + const auto interpid = GetCurrentPyInterpreterID(); + + ++sm_num_interpreters_seen; + EXPECT_TRUE( + sm_alive_interpids.insert(interpid).second, + "The current interpreter ID should not be already present in the alive interpreters " + "set."); + + auto ®istry1 = GetSingleton(); + auto ®istry2 = GetSingleton(); + + EXPECT_EQ(registry1.m_builtins_types.size(), registry2.m_builtins_types.size()); + EXPECT_LE(registry1.m_builtins_types.size(), registry1.m_registrations.size()); + EXPECT_EQ(registry1.m_registrations.size(), registry2.m_registrations.size() + 1); + EXPECT_EQ(registry1.m_named_registrations.size(), registry2.m_named_registrations.size()); + + py::getattr(py::module_::import("atexit"), "register")(py::cpp_function(&Clear)); +} + // NOLINTNEXTLINE[readability-function-cognitive-complexity] /*static*/ void PyTreeTypeRegistry::Clear() { const scoped_write_lock lock{sm_mutex}; const auto interpid = GetCurrentPyInterpreterID(); + EXPECT_NE(sm_alive_interpids.find(interpid), + sm_alive_interpids.end(), + "The current interpreter ID should be present in the alive interpreters set."); + sm_alive_interpids.erase(interpid); + auto ®istry1 = GetSingleton(); auto ®istry2 = GetSingleton(); - auto [registrations1, named_registrations1, builtins_types] = - registry1.GetRegistrationsForCurrentPyInterpreterLocked(); - auto [registrations2, named_registrations2, builtins_types_] = - registry2.GetRegistrationsForCurrentPyInterpreterLocked(); - - EXPECT_LE(builtins_types.size(), registrations1.size()); - EXPECT_EQ(registrations1.size(), registrations2.size() + 1); - EXPECT_EQ(named_registrations1.size(), named_registrations2.size()); - EXPECT_EQ(reinterpret_cast(&builtins_types), - reinterpret_cast(&builtins_types)); + EXPECT_EQ(registry1.m_builtins_types.size(), registry2.m_builtins_types.size()); + EXPECT_LE(registry1.m_builtins_types.size(), registry1.m_registrations.size()); + EXPECT_EQ(registry1.m_registrations.size(), registry2.m_registrations.size() + 1); + EXPECT_EQ(registry1.m_named_registrations.size(), 
registry2.m_named_registrations.size()); #if defined(Py_DEBUG) - for (const auto &cls : builtins_types) { - EXPECT_NE(registrations1.find(cls), registrations1.end()); + for (const auto &cls : registry1.m_builtins_types) { + EXPECT_NE(registry1.m_registrations.find(cls), registry1.m_registrations.end()); } - for (const auto &[cls2, registration2] : registrations2) { - const auto it1 = registrations1.find(cls2); - EXPECT_NE(it1, registrations1.end()); + for (const auto &cls : registry2.m_builtins_types) { + EXPECT_NE(registry2.m_registrations.find(cls), registry2.m_registrations.end()); + } + for (const auto &[cls2, registration2] : registry2.m_registrations) { + const auto it1 = registry1.m_registrations.find(cls2); + EXPECT_NE(it1, registry1.m_registrations.end()); const auto ®istration1 = it1->second; EXPECT_TRUE(registration1->type.is(registration2->type)); @@ -407,9 +367,9 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( EXPECT_TRUE(registration1->unflatten_func.is(registration2->unflatten_func)); EXPECT_TRUE(registration1->path_entry_type.is(registration2->path_entry_type)); } - for (const auto &[named_cls2, registration2] : named_registrations2) { - const auto it1 = named_registrations1.find(named_cls2); - EXPECT_NE(it1, named_registrations1.end()); + for (const auto &[named_cls2, registration2] : registry2.m_named_registrations) { + const auto it1 = registry1.m_named_registrations.find(named_cls2); + EXPECT_NE(it1, registry1.m_named_registrations.end()); const auto ®istration1 = it1->second; EXPECT_TRUE(registration1->type.is(registration2->type)); @@ -419,34 +379,28 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( } #endif - EXPECT_EQ(py::ssize_t_cast(sm_builtins_types.size()), sm_num_interpreters_alive); - - for (const auto &[_, registration] : registrations1) { - registration->type.dec_ref(); - registration->flatten_func.dec_ref(); - registration->unflatten_func.dec_ref(); - registration->path_entry_type.dec_ref(); + for (const auto &cls : registry1.m_builtins_types) { + cls.dec_ref(); } - for (const auto &[_, registration] : named_registrations1) { - registration->type.dec_ref(); - registration->flatten_func.dec_ref(); - registration->unflatten_func.dec_ref(); - registration->path_entry_type.dec_ref(); + for (const auto &[_, registration1] : registry1.m_registrations) { + registration1->type.dec_ref(); + registration1->flatten_func.dec_ref(); + registration1->unflatten_func.dec_ref(); + registration1->path_entry_type.dec_ref(); + } + for (const auto &[_, registration1] : registry1.m_named_registrations) { + registration1->type.dec_ref(); + registration1->flatten_func.dec_ref(); + registration1->unflatten_func.dec_ref(); + registration1->path_entry_type.dec_ref(); } - builtins_types.clear(); - registrations1.clear(); - named_registrations1.clear(); - registrations2.clear(); - named_registrations2.clear(); - - sm_builtins_types.erase(interpid); - registry1.m_registrations.erase(interpid); - registry1.m_named_registrations.erase(interpid); - registry2.m_registrations.erase(interpid); - registry2.m_named_registrations.erase(interpid); - - --sm_num_interpreters_alive; + registry1.m_builtins_types.clear(); + registry1.m_registrations.clear(); + registry1.m_named_registrations.clear(); + registry2.m_builtins_types.clear(); + registry2.m_registrations.clear(); + registry2.m_named_registrations.clear(); } } // namespace optree From 90afec28986547dfe336d1cf5350cc4d3046b3c4 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 15 Dec 2025 01:22:00 +0800 Subject: [PATCH 32/59] update 
--- src/registry.cpp | 7 +- tests/concurrent/test_subinterpreters.py | 179 +++++++++++++++++------ tests/helpers.py | 28 ++-- 3 files changed, 156 insertions(+), 58 deletions(-) diff --git a/src/registry.cpp b/src/registry.cpp index fa8f401d..62cd85ae 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -353,9 +353,14 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( #if defined(Py_DEBUG) for (const auto &cls : registry1.m_builtins_types) { EXPECT_NE(registry1.m_registrations.find(cls), registry1.m_registrations.end()); + EXPECT_NE(registry2.m_builtins_types.find(cls), registry2.m_builtins_types.end()); } for (const auto &cls : registry2.m_builtins_types) { - EXPECT_NE(registry2.m_registrations.find(cls), registry2.m_registrations.end()); + if (cls.is(PyNoneTypeObject)) [[unlikely]] { + EXPECT_EQ(registry2.m_registrations.find(cls), registry2.m_registrations.end()); + } else [[likely]] { + EXPECT_NE(registry2.m_registrations.find(cls), registry2.m_registrations.end()); + } } for (const auto &[cls2, registration2] : registry2.m_registrations) { const auto it1 = registry1.m_registrations.find(cls2); diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index 5ffc5088..a5668bdd 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -192,61 +192,152 @@ def test_import(): def test_import_in_subinterpreter_after_main(): - script = textwrap.dedent( - """ - import contextlib - import gc - from concurrent import interpreters + check_script_in_subprocess( + textwrap.dedent( + """ + import contextlib + import gc + from concurrent import interpreters - import optree + import optree - subinterpreter = None - with contextlib.closing(interpreters.create()) as subinterpreter: - subinterpreter.exec('import optree') + subinterpreter = None + with contextlib.closing(interpreters.create()) as subinterpreter: + subinterpreter.exec('import optree') + + del optree, subinterpreter + for _ in range(10): + gc.collect() + """, + ).strip(), + output='', + rerun=NUM_FLAKY_RERUNS, + ) + + check_script_in_subprocess( + textwrap.dedent( + f""" + import contextlib + import gc + from concurrent import interpreters - del optree, subinterpreter - for _ in range(10): - gc.collect() - """, - ).strip() - check_script_in_subprocess(script, output='', rerun=NUM_FLAKY_RERUNS) + import optree + + subinterpreter = subinterpreters = stack = None + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range({NUM_FUTURES}) + ] + for subinterpreter in subinterpreters: + subinterpreter.exec('import optree') + + del optree, subinterpreter, subinterpreters, stack + for _ in range(10): + gc.collect() + """, + ).strip(), + output='', + rerun=NUM_FLAKY_RERUNS, + ) def test_import_in_subinterpreter_before_main(): - script = textwrap.dedent( - """ - import contextlib - import gc - from concurrent import interpreters + check_script_in_subprocess( + textwrap.dedent( + """ + import contextlib + import gc + from concurrent import interpreters - subinterpreter = None - with contextlib.closing(interpreters.create()) as subinterpreter: - subinterpreter.exec('import optree') + subinterpreter = None + with contextlib.closing(interpreters.create()) as subinterpreter: + subinterpreter.exec('import optree') - import optree - del optree, subinterpreter - for _ in range(10): - gc.collect() - """, - ).strip() - check_script_in_subprocess(script, output='', rerun=NUM_FLAKY_RERUNS) + 
import optree + del optree, subinterpreter + for _ in range(10): + gc.collect() + """, + ).strip(), + output='', + rerun=NUM_FLAKY_RERUNS, + ) -def test_import_in_subinterpreters_concurrently(): - script = textwrap.dedent( - f""" - from concurrent.futures import InterpreterPoolExecutor, as_completed + check_script_in_subprocess( + textwrap.dedent( + f""" + import contextlib + import gc + from concurrent import interpreters + + subinterpreter = subinterpreters = stack = None + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range({NUM_FUTURES}) + ] + for subinterpreter in subinterpreters: + subinterpreter.exec('import optree') - def check_import(): import optree - if optree._C.get_registry_size() != 8: - raise RuntimeError('registry size mismatch') + del optree, subinterpreter, subinterpreters, stack + for _ in range(10): + gc.collect() + """, + ).strip(), + output='', + rerun=NUM_FLAKY_RERUNS, + ) - with InterpreterPoolExecutor(max_workers={NUM_WORKERS}) as executor: - futures = [executor.submit(check_import) for _ in range({NUM_FUTURES})] - for future in as_completed(futures): - future.result() - """, - ).strip() - check_script_in_subprocess(script, output='', rerun=NUM_FLAKY_RERUNS) + check_script_in_subprocess( + textwrap.dedent( + f""" + import contextlib + import gc + from concurrent import interpreters + + subinterpreter = subinterpreters = stack = None + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range({NUM_FUTURES}) + ] + for subinterpreter in subinterpreters: + subinterpreter.exec('import optree') + + import optree + + del optree, subinterpreter, subinterpreters, stack + for _ in range(10): + gc.collect() + """, + ).strip(), + output='', + rerun=NUM_FLAKY_RERUNS, + ) + + +def test_import_in_subinterpreters_concurrently(): + check_script_in_subprocess( + textwrap.dedent( + f""" + from concurrent.futures import InterpreterPoolExecutor, as_completed + + def check_import(): + import optree + + if optree._C.get_registry_size() != 8: + raise RuntimeError('registry size mismatch') + + with InterpreterPoolExecutor(max_workers={NUM_WORKERS}) as executor: + futures = [executor.submit(check_import) for _ in range({NUM_FUTURES})] + for future in as_completed(futures): + future.result() + """, + ).strip(), + output='', + rerun=NUM_FLAKY_RERUNS, + ) diff --git a/tests/helpers.py b/tests/helpers.py index 21e4cbb4..18716464 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -157,21 +157,23 @@ def wrapper(*args, **kwargs): def check_script_in_subprocess(script, /, *, output, env=None, cwd=TEST_ROOT, rerun=1): - if env is None: - env = os.environ - env = { - key: value for key, value in env.items() if not key.startswith(('PYTHON', 'PYTEST', 'COV_')) - } result = '' for _ in range(rerun): - result = subprocess.check_output( - [sys.executable, '-Walways', '-Werror', '-c', script], - stderr=subprocess.STDOUT, - text=True, - encoding='utf-8', - cwd=cwd, - env=env, - ) + try: + result = subprocess.check_output( + [sys.executable, '-Walways', '-Werror', '-c', script], + stderr=subprocess.STDOUT, + text=True, + encoding='utf-8', + cwd=cwd, + env={ + key: value + for key, value in (env if env is not None else os.environ).items() + if not key.startswith(('PYTHON', 'PYTEST', 'COV_')) + }, + ) + except subprocess.CalledProcessError as ex: + raise subprocess.SubprocessError(f'{ex}\nOutput:\n{ex.output}') from None if output is 
not None: assert result == output return result From ce661a2af046ed2397fe3e2c11567e47401f4c54 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 15 Dec 2025 18:53:58 +0800 Subject: [PATCH 33/59] chore: split ci jobs --- .github/workflows/tests-with-pydebug.yml | 3 ++- include/optree/registry.h | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests-with-pydebug.yml b/.github/workflows/tests-with-pydebug.yml index 1d86e36a..25a606b4 100644 --- a/.github/workflows/tests-with-pydebug.yml +++ b/.github/workflows/tests-with-pydebug.yml @@ -347,7 +347,8 @@ jobs: "--cov-report=xml:coverage-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml" "--junit-xml=junit-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml" ) - make test PYTESTOPTS="${PYTESTOPTS[*]}" + make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'subinterpreter'" + make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'not subinterpreter'" CORE_DUMP_FILES="$( find . -type d -path "./venv" -prune \ diff --git a/include/optree/registry.h b/include/optree/registry.h index ec5dbb39..fa00a50d 100644 --- a/include/optree/registry.h +++ b/include/optree/registry.h @@ -21,7 +21,6 @@ limitations under the License. #include // std::shared_ptr #include // std::optional, std::nullopt #include // std::string -#include // std::tuple #include // std::unordered_map #include // std::unordered_set #include // std::pair From afae64c63e76aefedf7dd27b1a37af69926ac27f Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 15 Dec 2025 20:12:05 +0800 Subject: [PATCH 34/59] fix: fix repr for exception --- tests/concurrent/test_subinterpreters.py | 3 --- tests/concurrent/test_threading.py | 8 +++++++- tests/helpers.py | 7 ++++++- tests/test_treespec.py | 3 +++ 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index a5668bdd..f48499ff 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -81,9 +81,6 @@ def concurrent_run(func, /, *args, **kwargs): return [future.result() for future in completed_futures] -run(object) # warm-up - - def check_module_importable(): import collections import time diff --git a/tests/concurrent/test_threading.py b/tests/concurrent/test_threading.py index 2d03897f..d1888b42 100644 --- a/tests/concurrent/test_threading.py +++ b/tests/concurrent/test_threading.py @@ -355,6 +355,12 @@ def test_tree_iter_thread_safe( namespace=namespace, ) num_leaves = next(counter) + assert optree.tree_leaves( + new_tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) == list(range(num_leaves)) + it = optree.tree_iter( new_tree, none_is_leaf=none_is_leaf, @@ -362,6 +368,6 @@ def test_tree_iter_thread_safe( ) results = concurrent_run(list, it) + assert sorted(itertools.chain.from_iterable(results)) == list(range(num_leaves)) for seq in results: assert sorted(seq) == seq - assert sorted(itertools.chain.from_iterable(results)) == list(range(num_leaves)) diff --git a/tests/helpers.py b/tests/helpers.py index 18716464..bb3bc686 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -156,6 +156,11 @@ def wrapper(*args, **kwargs): return wrapper +class CalledProcessError(subprocess.CalledProcessError): + def __str__(self): + return f'{super().__str__()}\nOutput:\n{self.output}' + + def check_script_in_subprocess(script, /, *, output, env=None, cwd=TEST_ROOT, rerun=1): result = '' for _ in range(rerun): @@ -173,7 +178,7 @@ def check_script_in_subprocess(script, /, *, output, env=None, cwd=TEST_ROOT, re }, ) 
except subprocess.CalledProcessError as ex: - raise subprocess.SubprocessError(f'{ex}\nOutput:\n{ex.output}') from None + raise CalledProcessError(ex.returncode, ex.cmd, ex.output, ex.stderr) from None if output is not None: assert result == output return result diff --git a/tests/test_treespec.py b/tests/test_treespec.py index 900572e8..5accf368 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -70,6 +70,9 @@ def test_treespec_construct(): treespec = optree.PyTreeSpec.__new__(optree.PyTreeSpec) with pytest.raises(TypeError, match=re.escape('No constructor defined!')): treespec.__init__() + del treespec + + gc_collect() script = textwrap.dedent( r""" From 5fb47690de527dff7e994db18c92dbc3af64cb5f Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 17 Dec 2025 12:44:18 +0800 Subject: [PATCH 35/59] test: skip failed tests --- tests/concurrent/test_subinterpreters.py | 112 ++++++++++++----------- 1 file changed, 59 insertions(+), 53 deletions(-) diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index f48499ff..420fca49 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -216,6 +216,7 @@ def test_import_in_subinterpreter_after_main(): f""" import contextlib import gc + import random from concurrent import interpreters import optree @@ -226,6 +227,7 @@ def test_import_in_subinterpreter_after_main(): stack.enter_context(contextlib.closing(interpreters.create())) for _ in range({NUM_FUTURES}) ] + random.shuffle(subinterpreters) for subinterpreter in subinterpreters: subinterpreter.exec('import optree') @@ -262,59 +264,63 @@ def test_import_in_subinterpreter_before_main(): rerun=NUM_FLAKY_RERUNS, ) - check_script_in_subprocess( - textwrap.dedent( - f""" - import contextlib - import gc - from concurrent import interpreters - - subinterpreter = subinterpreters = stack = None - with contextlib.ExitStack() as stack: - subinterpreters = [ - stack.enter_context(contextlib.closing(interpreters.create())) - for _ in range({NUM_FUTURES}) - ] - for subinterpreter in subinterpreters: - subinterpreter.exec('import optree') - - import optree - - del optree, subinterpreter, subinterpreters, stack - for _ in range(10): - gc.collect() - """, - ).strip(), - output='', - rerun=NUM_FLAKY_RERUNS, - ) - - check_script_in_subprocess( - textwrap.dedent( - f""" - import contextlib - import gc - from concurrent import interpreters - - subinterpreter = subinterpreters = stack = None - with contextlib.ExitStack() as stack: - subinterpreters = [ - stack.enter_context(contextlib.closing(interpreters.create())) - for _ in range({NUM_FUTURES}) - ] - for subinterpreter in subinterpreters: - subinterpreter.exec('import optree') - - import optree - - del optree, subinterpreter, subinterpreters, stack - for _ in range(10): - gc.collect() - """, - ).strip(), - output='', - rerun=NUM_FLAKY_RERUNS, - ) + # check_script_in_subprocess( + # textwrap.dedent( + # f""" + # import contextlib + # import gc + # import random + # from concurrent import interpreters + + # subinterpreter = subinterpreters = stack = None + # with contextlib.ExitStack() as stack: + # subinterpreters = [ + # stack.enter_context(contextlib.closing(interpreters.create())) + # for _ in range({NUM_FUTURES}) + # ] + # random.shuffle(subinterpreters) + # for subinterpreter in subinterpreters: + # subinterpreter.exec('import optree') + + # import optree + + # del optree, subinterpreter, subinterpreters, stack + # for _ in range(10): + # gc.collect() + # """, 
+ # ).strip(), + # output='', + # rerun=NUM_FLAKY_RERUNS, + # ) + + # check_script_in_subprocess( + # textwrap.dedent( + # f""" + # import contextlib + # import gc + # import random + # from concurrent import interpreters + + # subinterpreter = subinterpreters = stack = None + # with contextlib.ExitStack() as stack: + # subinterpreters = [ + # stack.enter_context(contextlib.closing(interpreters.create())) + # for _ in range({NUM_FUTURES}) + # ] + # random.shuffle(subinterpreters) + # for subinterpreter in subinterpreters: + # subinterpreter.exec('import optree') + + # import optree + + # del optree, subinterpreter, subinterpreters, stack + # for _ in range(10): + # gc.collect() + # """, + # ).strip(), + # output='', + # rerun=NUM_FLAKY_RERUNS, + # ) def test_import_in_subinterpreters_concurrently(): From e486981794e6c3687324fa1576203d07bf74dd37 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 17 Dec 2025 22:09:56 +0800 Subject: [PATCH 36/59] test: set no-cov for subinterpreter tests --- .github/workflows/tests-with-pydebug.yml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests-with-pydebug.yml b/.github/workflows/tests-with-pydebug.yml index 1af269fb..c32b4ea6 100644 --- a/.github/workflows/tests-with-pydebug.yml +++ b/.github/workflows/tests-with-pydebug.yml @@ -347,8 +347,13 @@ jobs: "--cov-report=xml:coverage-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml" "--junit-xml=junit-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml" ) - make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'subinterpreter'" - make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'not subinterpreter'" + + if ${{ env.PYTHON }} -c 'import sys, optree; sys.exit(not optree._C.OPTREE_HAS_SUBINTERPRETER_SUPPORT)'; then + make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'subinterpreter' --no-cov" + make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'not subinterpreter'" + else + make test PYTESTOPTS="${PYTESTOPTS[*]}" + fi CORE_DUMP_FILES="$( find . 
-type d -path "./venv" -prune \ From d89e120c3d877696e4aa17016aa1ff5a8eff1d59 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 18 Dec 2025 10:17:46 +0800 Subject: [PATCH 37/59] test: set env for subprocess --- tests/confcoverage.py | 5 ++++- tests/helpers.py | 7 +++++-- tests/test_treespec.py | 14 ++++++++------ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/confcoverage.py b/tests/confcoverage.py index 7997efd0..5635dee2 100644 --- a/tests/confcoverage.py +++ b/tests/confcoverage.py @@ -38,7 +38,10 @@ def is_importable(mod: str) -> bool: env = { key: value for key, value in os.environ.items() - if not key.startswith(('PYTHON', 'PYTEST', 'COV_')) + if ( + not key.startswith(('PYTHON', 'PYTEST', 'COV_')) + or key in ('PYTHON_GIL', 'PYTHONDEVMODE', 'PYTHONHASHSEED') + ) } try: subprocess.check_call( diff --git a/tests/helpers.py b/tests/helpers.py index f23978b9..727f2895 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -172,7 +172,7 @@ def check_script_in_subprocess(script, /, *, output, env=None, cwd=TEST_ROOT, re for _ in range(rerun): try: result = subprocess.check_output( - [sys.executable, '-Walways', '-Werror', '-c', script], + [sys.executable, '-u', '-X', 'dev', '-Walways', '-Werror', '-c', script], stderr=subprocess.STDOUT, text=True, encoding='utf-8', @@ -180,7 +180,10 @@ def check_script_in_subprocess(script, /, *, output, env=None, cwd=TEST_ROOT, re env={ key: value for key, value in (env if env is not None else os.environ).items() - if not key.startswith(('PYTHON', 'PYTEST', 'COV_')) + if ( + not key.startswith(('PYTHON', 'PYTEST', 'COV_')) + or key in ('PYTHON_GIL', 'PYTHONDEVMODE', 'PYTHONHASHSEED') + ) }, ) except subprocess.CalledProcessError as ex: diff --git a/tests/test_treespec.py b/tests/test_treespec.py index 5accf368..d63972ec 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -539,11 +539,6 @@ def test_treespec_pickle_missing_registration(): treespec = optree.tree_structure(Foo(0, 1), namespace='foo') serialized = pickle.dumps(treespec) - env = { - key: value - for key, value in os.environ.items() - if not key.startswith(('PYTHON', 'PYTEST', 'COV_')) - } try: output = subprocess.run( [ @@ -571,7 +566,14 @@ def test_treespec_pickle_missing_registration(): text=True, encoding='utf-8', cwd=TEST_ROOT, - env=env, + env={ + key: value + for key, value in os.environ.items() + if ( + not key.startswith(('PYTHON', 'PYTEST', 'COV_')) + or key in ('PYTHON_GIL', 'PYTHONDEVMODE', 'PYTHONHASHSEED') + ) + }, ) message = output.stdout.strip() except subprocess.CalledProcessError as ex: From 5efad7e25bff12aeb98dc5f1018c8bff3d4ee6e7 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 18 Dec 2025 21:39:43 +0800 Subject: [PATCH 38/59] chore: split tests --- .github/workflows/tests-with-pydebug.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests-with-pydebug.yml b/.github/workflows/tests-with-pydebug.yml index c32b4ea6..aec5148c 100644 --- a/.github/workflows/tests-with-pydebug.yml +++ b/.github/workflows/tests-with-pydebug.yml @@ -350,6 +350,7 @@ jobs: if ${{ env.PYTHON }} -c 'import sys, optree; sys.exit(not optree._C.OPTREE_HAS_SUBINTERPRETER_SUPPORT)'; then make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'subinterpreter' --no-cov" + make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'concurrent' --no-cov" make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'not subinterpreter'" else make test PYTESTOPTS="${PYTESTOPTS[*]}" From bb0a4d91a6921920e19bbd7fc2293f21b520ab46 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 19 Dec 2025 
18:52:00 +0800 Subject: [PATCH 39/59] chore(pre-commit): update pre-commit hooks --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5102a176..76f25677 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: args: [--ignore-case] files: ^docs/source/spelling_wordlist\.txt$ - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v21.1.7 + rev: v21.1.8 hooks: - id: clang-format - repo: https://github.com/cpplint/cpplint @@ -38,7 +38,7 @@ repos: hooks: - id: cpplint - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.9 + rev: v0.14.10 hooks: - id: ruff-check args: [--fix, --exit-non-zero-on-fix] From b7cb1cf63aa856e39975b1fb2b3f4a94dcf393e5 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 19 Dec 2025 18:53:04 +0800 Subject: [PATCH 40/59] test: enable subinterpreter tests --- tests/concurrent/test_subinterpreters.py | 235 +++++++++++------------ tests/helpers.py | 2 + tests/test_treespec.py | 37 ++-- 3 files changed, 132 insertions(+), 142 deletions(-) diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index 420fca49..93a15550 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -17,7 +17,6 @@ import contextlib import random import sys -import textwrap import pytest @@ -190,157 +189,145 @@ def test_import(): def test_import_in_subinterpreter_after_main(): check_script_in_subprocess( - textwrap.dedent( - """ - import contextlib - import gc - from concurrent import interpreters + """ + import contextlib + import gc + from concurrent import interpreters - import optree + import optree + + subinterpreter = None + with contextlib.closing(interpreters.create()) as subinterpreter: + subinterpreter.exec('import optree') + + del optree, subinterpreter + for _ in range(10): + gc.collect() + """, + output='', + rerun=NUM_FLAKY_RERUNS, + ) - subinterpreter = None - with contextlib.closing(interpreters.create()) as subinterpreter: + check_script_in_subprocess( + f""" + import contextlib + import gc + import random + from concurrent import interpreters + + import optree + + subinterpreter = subinterpreters = stack = None + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range({NUM_FUTURES}) + ] + random.shuffle(subinterpreters) + for subinterpreter in subinterpreters: subinterpreter.exec('import optree') - del optree, subinterpreter - for _ in range(10): - gc.collect() - """, - ).strip(), + del optree, subinterpreter, subinterpreters, stack + for _ in range(10): + gc.collect() + """, output='', rerun=NUM_FLAKY_RERUNS, ) + +def test_import_in_subinterpreter_before_main(): check_script_in_subprocess( - textwrap.dedent( - f""" - import contextlib - import gc - import random - from concurrent import interpreters + """ + import contextlib + import gc + from concurrent import interpreters - import optree + subinterpreter = None + with contextlib.closing(interpreters.create()) as subinterpreter: + subinterpreter.exec('import optree') - subinterpreter = subinterpreters = stack = None - with contextlib.ExitStack() as stack: - subinterpreters = [ - stack.enter_context(contextlib.closing(interpreters.create())) - for _ in range({NUM_FUTURES}) - ] - random.shuffle(subinterpreters) - for subinterpreter in subinterpreters: - subinterpreter.exec('import optree') - - del optree, 
subinterpreter, subinterpreters, stack - for _ in range(10): - gc.collect() - """, - ).strip(), + import optree + + del optree, subinterpreter + for _ in range(10): + gc.collect() + """, output='', rerun=NUM_FLAKY_RERUNS, ) + check_script_in_subprocess( + f""" + import contextlib + import gc + import random + from concurrent import interpreters + + subinterpreter = subinterpreters = stack = None + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range({NUM_FUTURES}) + ] + random.shuffle(subinterpreters) + for subinterpreter in subinterpreters: + subinterpreter.exec('import optree') + + import optree + + del optree, subinterpreter, subinterpreters, stack + for _ in range(10): + gc.collect() + """, + output='', + rerun=NUM_FLAKY_RERUNS, + ) -def test_import_in_subinterpreter_before_main(): check_script_in_subprocess( - textwrap.dedent( - """ - import contextlib - import gc - from concurrent import interpreters - - subinterpreter = None - with contextlib.closing(interpreters.create()) as subinterpreter: + f""" + import contextlib + import gc + import random + from concurrent import interpreters + + subinterpreter = subinterpreters = stack = None + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range({NUM_FUTURES}) + ] + random.shuffle(subinterpreters) + for subinterpreter in subinterpreters: subinterpreter.exec('import optree') import optree - del optree, subinterpreter - for _ in range(10): - gc.collect() - """, - ).strip(), + del optree, subinterpreter, subinterpreters, stack + for _ in range(10): + gc.collect() + """, output='', rerun=NUM_FLAKY_RERUNS, ) - # check_script_in_subprocess( - # textwrap.dedent( - # f""" - # import contextlib - # import gc - # import random - # from concurrent import interpreters - - # subinterpreter = subinterpreters = stack = None - # with contextlib.ExitStack() as stack: - # subinterpreters = [ - # stack.enter_context(contextlib.closing(interpreters.create())) - # for _ in range({NUM_FUTURES}) - # ] - # random.shuffle(subinterpreters) - # for subinterpreter in subinterpreters: - # subinterpreter.exec('import optree') - - # import optree - - # del optree, subinterpreter, subinterpreters, stack - # for _ in range(10): - # gc.collect() - # """, - # ).strip(), - # output='', - # rerun=NUM_FLAKY_RERUNS, - # ) - - # check_script_in_subprocess( - # textwrap.dedent( - # f""" - # import contextlib - # import gc - # import random - # from concurrent import interpreters - - # subinterpreter = subinterpreters = stack = None - # with contextlib.ExitStack() as stack: - # subinterpreters = [ - # stack.enter_context(contextlib.closing(interpreters.create())) - # for _ in range({NUM_FUTURES}) - # ] - # random.shuffle(subinterpreters) - # for subinterpreter in subinterpreters: - # subinterpreter.exec('import optree') - - # import optree - - # del optree, subinterpreter, subinterpreters, stack - # for _ in range(10): - # gc.collect() - # """, - # ).strip(), - # output='', - # rerun=NUM_FLAKY_RERUNS, - # ) - def test_import_in_subinterpreters_concurrently(): check_script_in_subprocess( - textwrap.dedent( - f""" - from concurrent.futures import InterpreterPoolExecutor, as_completed - - def check_import(): - import optree - - if optree._C.get_registry_size() != 8: - raise RuntimeError('registry size mismatch') - - with InterpreterPoolExecutor(max_workers={NUM_WORKERS}) as executor: - futures = 
[executor.submit(check_import) for _ in range({NUM_FUTURES})] - for future in as_completed(futures): - future.result() - """, - ).strip(), + f""" + from concurrent.futures import InterpreterPoolExecutor, as_completed + + def check_import(): + import optree + + if optree._C.get_registry_size() != 8: + raise RuntimeError('registry size mismatch') + + with InterpreterPoolExecutor(max_workers={NUM_WORKERS}) as executor: + futures = [executor.submit(check_import) for _ in range({NUM_FUTURES})] + for future in as_completed(futures): + future.result() + """, output='', rerun=NUM_FLAKY_RERUNS, ) diff --git a/tests/helpers.py b/tests/helpers.py index 727f2895..ff05a4d3 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -25,6 +25,7 @@ import subprocess import sys import sysconfig +import textwrap import time import types from collections import OrderedDict, UserDict, defaultdict, deque, namedtuple @@ -168,6 +169,7 @@ def __str__(self): def check_script_in_subprocess(script, /, *, output, env=None, cwd=TEST_ROOT, rerun=1): + script = textwrap.dedent(script).strip() result = '' for _ in range(rerun): try: diff --git a/tests/test_treespec.py b/tests/test_treespec.py index d63972ec..e6aeb165 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -74,27 +74,28 @@ def test_treespec_construct(): gc_collect() - script = textwrap.dedent( - r""" - import signal - import sys - - import optree - import optree._C - - for _ in range(32): - treespec = optree.PyTreeSpec.__new__(optree.PyTreeSpec) - try: - repr(treespec) - except optree._C.InternalError as ex: - assert 'src/treespec/serialization.cpp' in str(ex).replace('\\', '/') - sys.exit(0) - """, - ).strip() returncode = 0 try: with tempfile.TemporaryDirectory() as tmpdir: - check_script_in_subprocess(script, cwd=tmpdir, output=None) + check_script_in_subprocess( + r""" + import signal + import sys + + import optree + import optree._C + + for _ in range(32): + treespec = optree.PyTreeSpec.__new__(optree.PyTreeSpec) + try: + repr(treespec) + except optree._C.InternalError as ex: + assert 'src/treespec/serialization.cpp' in str(ex).replace('\\', '/') + sys.exit(0) + """, + cwd=tmpdir, + output=None, + ) except subprocess.CalledProcessError as ex: returncode = abs(ex.returncode) if 128 < returncode < 256: From f7abc85675aee19ea3c101122168d834e0fefd05 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 25 Dec 2025 14:13:04 +0800 Subject: [PATCH 41/59] chore: update nightly remote --- .github/workflows/set_setup_requires.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/set_setup_requires.py b/.github/workflows/set_setup_requires.py index 0e442ea1..1d6c3583 100755 --- a/.github/workflows/set_setup_requires.py +++ b/.github/workflows/set_setup_requires.py @@ -9,7 +9,7 @@ ROOT = Path(__file__).absolute().parents[2] PYPROJECT_FILE = ROOT / 'pyproject.toml' -PYBIND11_GIT_URL = 'https://github.com/XuehaiPan/pybind11.git@subinterp-call-once-and-store' +PYBIND11_GIT_URL = 'https://github.com/XuehaiPan/pybind11.git@fix-multiple-interpreters-concurrency' if __name__ == '__main__': From ea8cc2c97c116fd9cee1c7e10d3432685b46faad Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 25 Dec 2025 23:00:01 +0800 Subject: [PATCH 42/59] chore: add more build time meta --- optree/_C.pyi | 2 ++ src/optree.cpp | 2 ++ tests/helpers.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/optree/_C.pyi b/optree/_C.pyi index 61932d33..35e73649 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -49,6 +49,8 @@ Py_DEBUG: Final[bool] 
Py_GIL_DISABLED: Final[bool] PYBIND11_VERSION_HEX: Final[int] PYBIND11_INTERNALS_VERSION: Final[int] +PYBIND11_INTERNALS_ID: Final[str] +PYBIND11_MODULE_LOCAL_ID: Final[str] PYBIND11_HAS_NATIVE_ENUM: Final[bool] PYBIND11_HAS_INTERNALS_WITH_SMART_HOLDER_SUPPORT: Final[bool] PYBIND11_HAS_SUBINTERPRETER_SUPPORT: Final[bool] diff --git a/src/optree.cpp b/src/optree.cpp index 09659ba8..37e6d9bd 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -73,6 +73,8 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] #endif BUILDTIME_METADATA["PYBIND11_VERSION_HEX"] = py::int_(PYBIND11_VERSION_HEX); BUILDTIME_METADATA["PYBIND11_INTERNALS_VERSION"] = py::int_(PYBIND11_INTERNALS_VERSION); + BUILDTIME_METADATA["PYBIND11_INTERNALS_ID"] = py::str(PYBIND11_INTERNALS_ID); + BUILDTIME_METADATA["PYBIND11_MODULE_LOCAL_ID"] = py::str(PYBIND11_MODULE_LOCAL_ID); #if defined(PYBIND11_HAS_NATIVE_ENUM) && NONZERO_OR_EMPTY(PYBIND11_HAS_NATIVE_ENUM) BUILDTIME_METADATA["PYBIND11_HAS_NATIVE_ENUM"] = py::bool_(true); #else diff --git a/tests/helpers.py b/tests/helpers.py index ff05a4d3..bf37d176 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -38,6 +38,7 @@ from optree._C import ( OPTREE_HAS_SUBINTERPRETER_SUPPORT, PYBIND11_HAS_NATIVE_ENUM, + PYBIND11_HAS_SUBINTERPRETER_SUPPORT, Py_DEBUG, Py_GIL_DISABLED, get_registry_size, @@ -54,6 +55,7 @@ assert INITIAL_REGISTRY_SIZE + 2 == len(NODETYPE_REGISTRY) _ = PYBIND11_HAS_NATIVE_ENUM +_ = PYBIND11_HAS_SUBINTERPRETER_SUPPORT _ = OPTREE_HAS_SUBINTERPRETER_SUPPORT if sysconfig.get_config_var('Py_DEBUG') is None: From 2935452e04cb8a23a6ad84fe54962b6ce261d963 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 27 Dec 2025 04:07:57 +0800 Subject: [PATCH 43/59] chore: update nightly remote --- .github/workflows/set_setup_requires.py | 2 +- .github/workflows/tests-with-pydebug.yml | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/set_setup_requires.py b/.github/workflows/set_setup_requires.py index 1d6c3583..df07c011 100755 --- a/.github/workflows/set_setup_requires.py +++ b/.github/workflows/set_setup_requires.py @@ -9,7 +9,7 @@ ROOT = Path(__file__).absolute().parents[2] PYPROJECT_FILE = ROOT / 'pyproject.toml' -PYBIND11_GIT_URL = 'https://github.com/XuehaiPan/pybind11.git@fix-multiple-interpreters-concurrency' +PYBIND11_GIT_URL = 'https://github.com/pybind/pybind11.git' if __name__ == '__main__': diff --git a/.github/workflows/tests-with-pydebug.yml b/.github/workflows/tests-with-pydebug.yml index aec5148c..2573e64e 100644 --- a/.github/workflows/tests-with-pydebug.yml +++ b/.github/workflows/tests-with-pydebug.yml @@ -349,7 +349,6 @@ jobs: ) if ${{ env.PYTHON }} -c 'import sys, optree; sys.exit(not optree._C.OPTREE_HAS_SUBINTERPRETER_SUPPORT)'; then - make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'subinterpreter' --no-cov" make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'concurrent' --no-cov" make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'not subinterpreter'" else From fb08dec746e6ffffe0723eb2d637e87387960271 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 27 Dec 2025 11:51:08 +0800 Subject: [PATCH 44/59] chore: update test --- tests/concurrent/test_threading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/concurrent/test_threading.py b/tests/concurrent/test_threading.py index d1888b42..49857cab 100644 --- a/tests/concurrent/test_threading.py +++ b/tests/concurrent/test_threading.py @@ -367,7 +367,7 @@ def test_tree_iter_thread_safe( namespace=namespace, ) - results = concurrent_run(list, it) + results = 
concurrent_run(lambda x: list(x), it) assert sorted(itertools.chain.from_iterable(results)) == list(range(num_leaves)) for seq in results: assert sorted(seq) == seq From 0d18d19088d4e5640c0067bcd0ab4fb60339b6f9 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 27 Dec 2025 17:24:28 +0800 Subject: [PATCH 45/59] chore: update test --- include/optree/pymacros.h | 33 +++++++----------------- optree/_C.pyi | 1 + src/optree.cpp | 3 +++ tests/concurrent/test_subinterpreters.py | 16 ++++++++++-- tests/concurrent/test_threading.py | 2 +- 5 files changed, 29 insertions(+), 26 deletions(-) diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h index c8063bcb..d9eed80b 100644 --- a/include/optree/pymacros.h +++ b/include/optree/pymacros.h @@ -17,7 +17,6 @@ limitations under the License. #pragma once -#include // std::int64_t #include // std::runtime_error #include @@ -110,20 +109,22 @@ Py_Declare_ID(n_fields); // structseq.n_fields Py_Declare_ID(n_sequence_fields); // structseq.n_sequence_fields Py_Declare_ID(n_unnamed_fields); // structseq.n_unnamed_fields -using interpid_t = std::int64_t; +using interpid_t = decltype(PyInterpreterState_GetID(nullptr)); -#if defined(OPTREE_HAS_SUBINTERPRETER_SUPPORT) +[[nodiscard]] inline bool IsCurrentPyInterpreterMain() { + return PyInterpreterState_Get() == PyInterpreterState_Main(); +} [[nodiscard]] inline interpid_t GetCurrentPyInterpreterID() { PyInterpreterState *interp = PyInterpreterState_Get(); - if (PyErr_Occurred() != nullptr) [[unlikely]] { + if (PyErr_Occurred()) [[unlikely]] { throw py::error_already_set(); } - if (interp == nullptr) [[unlikely]] { + if (!interp) [[unlikely]] { throw std::runtime_error("Failed to get the current Python interpreter state."); } const interpid_t interpid = PyInterpreterState_GetID(interp); - if (PyErr_Occurred() != nullptr) [[unlikely]] { + if (PyErr_Occurred()) [[unlikely]] { throw py::error_already_set(); } return interpid; @@ -131,29 +132,15 @@ using interpid_t = std::int64_t; [[nodiscard]] inline interpid_t GetMainPyInterpreterID() { PyInterpreterState *interp = PyInterpreterState_Main(); - if (PyErr_Occurred() != nullptr) [[unlikely]] { + if (PyErr_Occurred()) [[unlikely]] { throw py::error_already_set(); } - if (interp == nullptr) [[unlikely]] { + if (!interp) [[unlikely]] { throw std::runtime_error("Failed to get the main Python interpreter state."); } const interpid_t interpid = PyInterpreterState_GetID(interp); - if (PyErr_Occurred() != nullptr) [[unlikely]] { + if (PyErr_Occurred()) [[unlikely]] { throw py::error_already_set(); } return interpid; } - -#else - -[[nodiscard]] inline constexpr interpid_t GetCurrentPyInterpreterID() noexcept { - // Fallback for Python versions < 3.14 or when subinterpreter support is not available. - return 0; -} - -[[nodiscard]] inline constexpr interpid_t GetMainPyInterpreterID() noexcept { - // Fallback for Python versions < 3.14 or when subinterpreter support is not available. - return 0; -} - -#endif diff --git a/optree/_C.pyi b/optree/_C.pyi index 35e73649..f791c9e7 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -220,5 +220,6 @@ def get_registry_size(namespace: str | None = None) -> int: ... def get_num_interpreters_seen() -> int: ... def get_num_interpreters_alive() -> int: ... def get_alive_interpreter_ids() -> set[int]: ... +def is_current_interpreter_main() -> bool: ... def get_current_interpreter_id() -> int: ... def get_main_interpreter_id() -> int: ... 
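
The new `is_current_interpreter_main()` binding above is intended to agree with comparing `get_current_interpreter_id()` against `get_main_interpreter_id()`, which is exactly the invariant the subinterpreter test asserts. A minimal standalone sketch of that invariant, assuming Python 3.14+ and an optree build with subinterpreter support; the helper name `check_identity` is illustrative and not part of the test suite:

# Sketch only: requires Python 3.14+ and a build with OPTREE_HAS_SUBINTERPRETER_SUPPORT.
from concurrent.futures import InterpreterPoolExecutor


def check_identity():
    import optree._C

    # The two queries added by this patch should always agree with each other.
    is_main = optree._C.is_current_interpreter_main()
    same_id = optree._C.get_current_interpreter_id() == optree._C.get_main_interpreter_id()
    assert is_main == same_id
    return is_main


assert check_identity()  # running in the main interpreter
with InterpreterPoolExecutor(max_workers=1) as executor:
    # Work submitted to the pool runs in a subinterpreter, so the flag flips.
    assert not executor.submit(check_identity).result()
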
diff --git a/src/optree.cpp b/src/optree.cpp index 37e6d9bd..bec618fa 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -169,6 +169,9 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] .def("get_alive_interpreter_ids", &PyTreeTypeRegistry::GetAliveInterpreterIDs, "Get the IDs of alive interpreters that have seen the registry.") + .def("is_current_interpreter_main", + &IsCurrentPyInterpreterMain, + "Check whether the current interpreter is the main interpreter.") .def("get_current_interpreter_id", &GetCurrentPyInterpreterID, "Get the ID of the current interpreter.") diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index 93a15550..41c9dc07 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -86,7 +86,14 @@ def check_module_importable(): import optree._C - if optree._C.get_registry_size() != 8: + is_current_interpreter_main = optree._C.is_current_interpreter_main() + main_interpreter_id = optree._C.get_main_interpreter_id() + current_interpreter_id = optree._C.get_current_interpreter_id() + + if is_current_interpreter_main != (main_interpreter_id == current_interpreter_id): + raise RuntimeError('interpreter identity mismatch') + + if not is_current_interpreter_main and optree._C.get_registry_size() != 8: raise RuntimeError('registry size mismatch') tree = { @@ -136,7 +143,8 @@ def check_module_importable(): _ = optree.tree_flatten_with_accessor(tree, none_is_leaf=True) return ( - optree._C.get_main_interpreter_id(), + is_current_interpreter_main, + main_interpreter_id, id(type(None)), id(tuple), id(list), @@ -149,6 +157,7 @@ def test_import(): import collections expected = ( + False, 0, id(type(None)), id(tuple), @@ -157,6 +166,7 @@ def test_import(): id(collections.OrderedDict), ) + assert check_module_importable() == (True, *expected[1:]) assert run(check_module_importable) == expected for _ in range(random.randint(5, 10)): @@ -322,6 +332,8 @@ def check_import(): if optree._C.get_registry_size() != 8: raise RuntimeError('registry size mismatch') + if optree._C.is_current_interpreter_main(): + raise RuntimeError('expected subinterpreter') with InterpreterPoolExecutor(max_workers={NUM_WORKERS}) as executor: futures = [executor.submit(check_import) for _ in range({NUM_FUTURES})] diff --git a/tests/concurrent/test_threading.py b/tests/concurrent/test_threading.py index 49857cab..d1888b42 100644 --- a/tests/concurrent/test_threading.py +++ b/tests/concurrent/test_threading.py @@ -367,7 +367,7 @@ def test_tree_iter_thread_safe( namespace=namespace, ) - results = concurrent_run(lambda x: list(x), it) + results = concurrent_run(list, it) assert sorted(itertools.chain.from_iterable(results)) == list(range(num_leaves)) for seq in results: assert sorted(seq) == seq From 3996bb081ccf24036ab65e1e06fc606ad16dd6df Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 27 Dec 2025 17:31:42 +0800 Subject: [PATCH 46/59] chore: update macros --- include/optree/pymacros.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h index d9eed80b..068d8c4c 100644 --- a/include/optree/pymacros.h +++ b/include/optree/pymacros.h @@ -34,8 +34,8 @@ limitations under the License. 
// NOLINTNEXTLINE[bugprone-macro-parentheses] #define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0)) -#if !defined(PYPY_VERSION) && \ - (defined(PY_VERSION_HEX) && PY_VERSION_HEX >= 0x030E0000 /* Python 3.14 */) && \ +#if !defined(PYPY_VERSION) && (PY_VERSION_HEX >= 0x030E0000 /* Python 3.14 */) && \ + (PYBIND11_VERSION_HEX >= 0x030002A0 /* pybind11 3.0.2.a0 */) && \ (defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \ NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT)) # define OPTREE_HAS_SUBINTERPRETER_SUPPORT 1 From c420862bbc7788e367affcef886accfa34a52e98 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 27 Dec 2025 19:50:53 +0800 Subject: [PATCH 47/59] chore: remove `Py_Get_ID` --- include/optree/pymacros.h | 50 ++++------------------------------ include/optree/pytypes.h | 27 +++++++++--------- src/optree.cpp | 6 ++-- src/treespec/constructors.cpp | 9 +++--- src/treespec/flatten.cpp | 18 ++++++------ src/treespec/serialization.cpp | 15 +++++----- src/treespec/treespec.cpp | 4 +-- 7 files changed, 43 insertions(+), 86 deletions(-) diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h index 068d8c4c..41707ff3 100644 --- a/include/optree/pymacros.h +++ b/include/optree/pymacros.h @@ -71,44 +71,6 @@ inline constexpr Py_ALWAYS_INLINE bool Py_IsConstant(PyObject *x) noexcept { } #define Py_IsConstant(x) Py_IsConstant(x) -#define Py_Declare_ID(name) \ - inline namespace { \ - [[nodiscard]] inline PyObject *Py_ID_##name() { \ - PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; \ - return storage \ - .call_once_and_store_result([]() -> PyObject * { \ - PyObject * const ptr = PyUnicode_InternFromString(#name); \ - if (ptr == nullptr) [[unlikely]] { \ - throw py::error_already_set(); \ - } \ - Py_INCREF(ptr); /* leak a reference on purpose */ \ - return ptr; \ - }) \ - .get_stored(); \ - } \ - } // namespace - -#define Py_Get_ID(name) (::Py_ID_##name()) - -Py_Declare_ID(optree); -Py_Declare_ID(__main__); // __main__ -Py_Declare_ID(__module__); // type.__module__ -Py_Declare_ID(__qualname__); // type.__qualname__ -Py_Declare_ID(__name__); // type.__name__ -Py_Declare_ID(sort); // list.sort -Py_Declare_ID(copy); // dict.copy -Py_Declare_ID(OrderedDict); // OrderedDict -Py_Declare_ID(defaultdict); // defaultdict -Py_Declare_ID(deque); // deque -Py_Declare_ID(default_factory); // defaultdict.default_factory -Py_Declare_ID(maxlen); // deque.maxlen -Py_Declare_ID(_fields); // namedtuple._fields -Py_Declare_ID(_make); // namedtuple._make -Py_Declare_ID(_asdict); // namedtuple._asdict -Py_Declare_ID(n_fields); // structseq.n_fields -Py_Declare_ID(n_sequence_fields); // structseq.n_sequence_fields -Py_Declare_ID(n_unnamed_fields); // structseq.n_unnamed_fields - using interpid_t = decltype(PyInterpreterState_GetID(nullptr)); [[nodiscard]] inline bool IsCurrentPyInterpreterMain() { @@ -117,14 +79,14 @@ using interpid_t = decltype(PyInterpreterState_GetID(nullptr)); [[nodiscard]] inline interpid_t GetCurrentPyInterpreterID() { PyInterpreterState *interp = PyInterpreterState_Get(); - if (PyErr_Occurred()) [[unlikely]] { + if (PyErr_Occurred() != nullptr) [[unlikely]] { throw py::error_already_set(); } - if (!interp) [[unlikely]] { + if (interp == nullptr) [[unlikely]] { throw std::runtime_error("Failed to get the current Python interpreter state."); } const interpid_t interpid = PyInterpreterState_GetID(interp); - if (PyErr_Occurred()) [[unlikely]] { + if (PyErr_Occurred() != nullptr) [[unlikely]] { throw py::error_already_set(); } return interpid; @@ 
-132,14 +94,14 @@ using interpid_t = decltype(PyInterpreterState_GetID(nullptr)); [[nodiscard]] inline interpid_t GetMainPyInterpreterID() { PyInterpreterState *interp = PyInterpreterState_Main(); - if (PyErr_Occurred()) [[unlikely]] { + if (PyErr_Occurred() != nullptr) [[unlikely]] { throw py::error_already_set(); } - if (!interp) [[unlikely]] { + if (interp == nullptr) [[unlikely]] { throw std::runtime_error("Failed to get the main Python interpreter state."); } const interpid_t interpid = PyInterpreterState_GetID(interp); - if (PyErr_Occurred()) [[unlikely]] { + if (PyErr_Occurred() != nullptr) [[unlikely]] { throw py::error_already_set(); } return interpid; diff --git a/include/optree/pytypes.h b/include/optree/pytypes.h index b031929b..a932bd9e 100644 --- a/include/optree/pytypes.h +++ b/include/optree/pytypes.h @@ -218,8 +218,7 @@ inline bool IsNamedTupleClassImpl(const py::handle &type) { // We can only identify namedtuples heuristically, here by the presence of a _fields attribute. if (PyType_FastSubclass(reinterpret_cast(type.ptr()), Py_TPFLAGS_TUPLE_SUBCLASS)) [[unlikely]] { - if (PyObject * const _fields = PyObject_GetAttr(type.ptr(), Py_Get_ID(_fields))) - [[unlikely]] { + if (PyObject * const _fields = PyObject_GetAttrString(type.ptr(), "_fields")) [[unlikely]] { bool fields_ok = static_cast(PyTuple_CheckExact(_fields)); if (fields_ok) [[likely]] { for (const auto &field : py::reinterpret_borrow(_fields)) { @@ -232,8 +231,9 @@ inline bool IsNamedTupleClassImpl(const py::handle &type) { Py_DECREF(_fields); if (fields_ok) [[likely]] { // NOLINTNEXTLINE[readability-use-anyofallof] - for (PyObject * const name : {Py_Get_ID(_make), Py_Get_ID(_asdict)}) { - if (PyObject * const attr = PyObject_GetAttr(type.ptr(), name)) [[likely]] { + for (const char * const name : {"_make", "_asdict"}) { + if (PyObject * const attr = PyObject_GetAttrString(type.ptr(), name)) + [[likely]] { const bool result = static_cast(PyCallable_Check(attr)); Py_DECREF(attr); if (!result) [[unlikely]] { @@ -311,7 +311,7 @@ inline py::tuple NamedTupleGetFields(const py::handle &object) { PyRepr(object) + "."); } } - return EVALUATE_WITH_LOCK_HELD(py::getattr(type, Py_Get_ID(_fields)), type); + return EVALUATE_WITH_LOCK_HELD(py::getattr(type, "_fields"), type); } inline bool IsStructSequenceClassImpl(const py::handle &type) { @@ -325,9 +325,8 @@ inline bool IsStructSequenceClassImpl(const py::handle &type) { PyTuple_GET_ITEM(type_object->tp_bases, 0) == reinterpret_cast(&PyTuple_Type)) [[unlikely]] { // NOLINTNEXTLINE[readability-use-anyofallof] - for (PyObject * const name : - {Py_Get_ID(n_fields), Py_Get_ID(n_sequence_fields), Py_Get_ID(n_unnamed_fields)}) { - if (PyObject * const attr = PyObject_GetAttr(type.ptr(), name)) [[unlikely]] { + for (const char * const name : {"n_fields", "n_sequence_fields", "n_unnamed_fields"}) { + if (PyObject * const attr = PyObject_GetAttrString(type.ptr(), name)) [[unlikely]] { const bool result = static_cast(PyLong_CheckExact(attr)); Py_DECREF(attr); if (!result) [[unlikely]] { @@ -418,7 +417,7 @@ inline py::tuple StructSequenceGetFieldsImpl(const py::handle &type) { return py::tuple{fields}; #else const auto n_sequence_fields = thread_safe_cast( - EVALUATE_WITH_LOCK_HELD(py::getattr(type, Py_Get_ID(n_sequence_fields)), type)); + EVALUATE_WITH_LOCK_HELD(py::getattr(type, "n_sequence_fields"), type)); const auto * const members = reinterpret_cast(type.ptr())->tp_members; py::tuple fields{n_sequence_fields}; for (py::ssize_t i = 0; i < n_sequence_fields; ++i) { @@ -489,15 +488,15 
@@ inline void TotalOrderSort(py::list &list) { // NOLINT[runtime/references] // Sort with `(f'{obj.__class__.__module__}.{obj.__class__.__qualname__}', obj)` const auto sort_key_fn = py::cpp_function([](const py::object &obj) -> py::tuple { const py::handle cls = py::type::handle_of(obj); - const py::str qualname{EVALUATE_WITH_LOCK_HELD( - PyStr(py::getattr(cls, Py_Get_ID(__module__))) + "." + - PyStr(py::getattr(cls, Py_Get_ID(__qualname__))), - cls)}; + const py::str qualname{ + EVALUATE_WITH_LOCK_HELD(PyStr(py::getattr(cls, "__module__")) + "." + + PyStr(py::getattr(cls, "__qualname__")), + cls)}; return py::make_tuple(qualname, obj); }); { const scoped_critical_section cs{list}; - py::getattr(list, Py_Get_ID(sort))(py::arg("key") = sort_key_fn); + py::getattr(list, "sort")(py::arg("key") = sort_key_fn); } } catch (py::error_already_set &ex2) { if (ex2.matches(PyExc_TypeError)) [[likely]] { diff --git a/src/optree.cpp b/src/optree.cpp index bec618fa..d9ca8e17 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -299,7 +299,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] #endif auto * const PyTreeKind_Type = reinterpret_cast(PyTreeKindTypeObject.ptr()); PyTreeKind_Type->tp_name = "optree.PyTreeKind"; - py::setattr(PyTreeKindTypeObject, Py_Get_ID(__module__), Py_Get_ID(optree)); + py::setattr(PyTreeKindTypeObject, "__module__", py::str("optree")); py::setattr(PyTreeKindTypeObject, "NUM_KINDS", py::int_(py::ssize_t(PyTreeKind::NumKinds))); auto PyTreeSpecTypeObject = @@ -323,7 +323,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] py::module_local()); auto * const PyTreeSpec_Type = reinterpret_cast(PyTreeSpecTypeObject.ptr()); PyTreeSpec_Type->tp_name = "optree.PyTreeSpec"; - py::setattr(PyTreeSpecTypeObject, Py_Get_ID(__module__), Py_Get_ID(optree)); + py::setattr(PyTreeSpecTypeObject, "__module__", py::str("optree")); PyTreeSpecTypeObject .def("unflatten", @@ -521,7 +521,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] py::module_local()); auto * const PyTreeIter_Type = reinterpret_cast(PyTreeIterTypeObject.ptr()); PyTreeIter_Type->tp_name = "optree.PyTreeIter"; - py::setattr(PyTreeIterTypeObject, Py_Get_ID(__module__), Py_Get_ID(optree)); + py::setattr(PyTreeIterTypeObject, "__module__", py::str("optree")); PyTreeIterTypeObject .def(py::init, bool, std::string>(), diff --git a/src/treespec/constructors.cpp b/src/treespec/constructors.cpp index f8c37e5b..cc4ec090 100644 --- a/src/treespec/constructors.cpp +++ b/src/treespec/constructors.cpp @@ -169,7 +169,7 @@ template node.arity = DictGetSize(dict); keys = DictKeys(dict); if (node.kind != PyTreeKind::OrderedDict) [[likely]] { - node.original_keys = py::getattr(keys, Py_Get_ID(copy))(); + node.original_keys = py::getattr(keys, "copy")(); if (!IsDictInsertionOrdered(registry_namespace)) [[likely]] { TotalOrderSort(keys); } @@ -181,8 +181,8 @@ template verify_children(children, treespecs); if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { const scoped_critical_section cs{handle}; - node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)), - std::move(keys)); + node.node_data = + py::make_tuple(py::getattr(handle, "default_factory"), std::move(keys)); } else [[likely]] { node.node_data = std::move(keys); } @@ -204,8 +204,7 @@ template case PyTreeKind::Deque: { const auto list = thread_safe_cast(handle); node.arity = ListGetSize(list); - node.node_data = - EVALUATE_WITH_LOCK_HELD(py::getattr(handle, Py_Get_ID(maxlen)), handle); + node.node_data 
= EVALUATE_WITH_LOCK_HELD(py::getattr(handle, "maxlen"), handle); for (ssize_t i = 0; i < node.arity; ++i) { children.emplace_back(ListGetItem(list, i)); } diff --git a/src/treespec/flatten.cpp b/src/treespec/flatten.cpp index 2271f96e..8c630224 100644 --- a/src/treespec/flatten.cpp +++ b/src/treespec/flatten.cpp @@ -106,7 +106,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle &handle, node.arity = DictGetSize(dict); keys = DictKeys(dict); if (node.kind != PyTreeKind::OrderedDict) [[likely]] { - node.original_keys = py::getattr(keys, Py_Get_ID(copy))(); + node.original_keys = py::getattr(keys, "copy")(); if constexpr (DictShouldBeSorted) { TotalOrderSort(keys); } @@ -117,8 +117,8 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle &handle, } if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { const scoped_critical_section cs{handle}; - node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)), - std::move(keys)); + node.node_data = + py::make_tuple(py::getattr(handle, "default_factory"), std::move(keys)); } else [[likely]] { node.node_data = std::move(keys); } @@ -139,8 +139,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle &handle, case PyTreeKind::Deque: { const auto list = thread_safe_cast(handle); node.arity = ListGetSize(list); - node.node_data = - EVALUATE_WITH_LOCK_HELD(py::getattr(handle, Py_Get_ID(maxlen)), handle); + node.node_data = EVALUATE_WITH_LOCK_HELD(py::getattr(handle, "maxlen"), handle); for (ssize_t i = 0; i < node.arity; ++i) { recurse(ListGetItem(list, i)); } @@ -371,7 +370,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle &handle, node.arity = DictGetSize(dict); py::list keys = DictKeys(dict); if (node.kind != PyTreeKind::OrderedDict) [[likely]] { - node.original_keys = py::getattr(keys, Py_Get_ID(copy))(); + node.original_keys = py::getattr(keys, "copy")(); if constexpr (DictShouldBeSorted) { TotalOrderSort(keys); } @@ -380,8 +379,8 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle &handle, recurse(DictGetItem(dict, key), key); } if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { - node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)), - std::move(keys)); + node.node_data = + py::make_tuple(py::getattr(handle, "default_factory"), std::move(keys)); } else [[likely]] { node.node_data = std::move(keys); } @@ -402,8 +401,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle &handle, case PyTreeKind::Deque: { const auto list = thread_safe_cast(handle); node.arity = ListGetSize(list); - node.node_data = - EVALUATE_WITH_LOCK_HELD(py::getattr(handle, Py_Get_ID(maxlen)), handle); + node.node_data = EVALUATE_WITH_LOCK_HELD(py::getattr(handle, "maxlen"), handle); for (ssize_t i = 0; i < node.arity; ++i) { recurse(ListGetItem(list, i), py::int_(i)); } diff --git a/src/treespec/serialization.cpp b/src/treespec/serialization.cpp index 06556588..a0e4249b 100644 --- a/src/treespec/serialization.cpp +++ b/src/treespec/serialization.cpp @@ -141,7 +141,7 @@ std::string PyTreeSpec::ToStringImpl() const { node.arity, "Number of fields and entries does not match."); const std::string kind = - PyStr(EVALUATE_WITH_LOCK_HELD(py::getattr(type, Py_Get_ID(__name__)), type)); + PyStr(EVALUATE_WITH_LOCK_HELD(py::getattr(type, "__name__"), type)); sstream << kind << "("; bool first = true; auto child_it = agenda.cend() - node.arity; @@ -195,9 +195,8 @@ std::string PyTreeSpec::ToStringImpl() const { EXPECT_EQ(TupleGetSize(fields), node.arity, "Number of fields and entries does not match."); - 
const py::object module_name = EVALUATE_WITH_LOCK_HELD( - py::getattr(type, Py_Get_ID(__module__), Py_Get_ID(__main__)), - type); + const py::object module_name = + EVALUATE_WITH_LOCK_HELD(py::getattr(type, "__module__", py::none()), type); if (!module_name.is_none()) [[likely]] { const std::string name = PyStr(module_name); if (!(name.empty() || name == "__main__" || name == "builtins" || @@ -206,7 +205,7 @@ std::string PyTreeSpec::ToStringImpl() const { } } const py::object qualname = - EVALUATE_WITH_LOCK_HELD(py::getattr(type, Py_Get_ID(__qualname__)), type); + EVALUATE_WITH_LOCK_HELD(py::getattr(type, "__qualname__"), type); sstream << PyStr(qualname) << "("; bool first = true; auto child_it = agenda.cend() - node.arity; @@ -223,9 +222,9 @@ std::string PyTreeSpec::ToStringImpl() const { } case PyTreeKind::Custom: { - const std::string kind = PyStr( - EVALUATE_WITH_LOCK_HELD(py::getattr(node.custom->type, Py_Get_ID(__name__)), - node.custom->type)); + const std::string kind = + PyStr(EVALUATE_WITH_LOCK_HELD(py::getattr(node.custom->type, "__name__"), + node.custom->type)); sstream << "CustomTreeNode(" << kind << "["; if (node.node_data) [[likely]] { sstream << PyRepr(node.node_data); diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index 165ee665..599934ef 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -856,11 +856,11 @@ py::list PyTreeSpec::Entries() const { case PyTreeKind::Dict: case PyTreeKind::OrderedDict: { const scoped_critical_section cs{root.node_data}; - return py::getattr(root.node_data, Py_Get_ID(copy))(); + return py::getattr(root.node_data, "copy")(); } case PyTreeKind::DefaultDict: { const scoped_critical_section cs{root.node_data}; - return py::getattr(TupleGetItem(root.node_data, 1), Py_Get_ID(copy))(); + return py::getattr(TupleGetItem(root.node_data, 1), "copy")(); } case PyTreeKind::NumKinds: From 56d991f2012eb4adb9d79cae0508b289b98b8f9d Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sun, 28 Dec 2025 12:48:38 +0800 Subject: [PATCH 48/59] chore: cleanup dict order namespaces --- include/optree/treespec.h | 26 +++++++----- src/registry.cpp | 21 ++++++++++ src/treespec/flatten.cpp | 4 +- tests/concurrent/test_subinterpreters.py | 7 ++-- tests/concurrent/test_threading.py | 52 ++++++++++++------------ 5 files changed, 68 insertions(+), 42 deletions(-) diff --git a/include/optree/treespec.h b/include/optree/treespec.h index 1ff25883..095354c2 100644 --- a/include/optree/treespec.h +++ b/include/optree/treespec.h @@ -23,7 +23,7 @@ limitations under the License. 
#include // std::thread::id #include // std::tuple #include // std::unordered_set -#include // std::pair +#include // std::pair, std::make_pair #include // std::vector #include @@ -263,28 +263,31 @@ class PyTreeSpec { [[nodiscard]] static inline Py_ALWAYS_INLINE bool IsDictInsertionOrdered( const std::string ®istry_namespace, const bool &inherit_global_namespace = true) { - const scoped_read_lock lock{sm_is_dict_insertion_ordered_mutex}; + const scoped_read_lock lock{sm_dict_order_mutex}; - return (sm_is_dict_insertion_ordered.find(registry_namespace) != - sm_is_dict_insertion_ordered.end()) || - (inherit_global_namespace && - sm_is_dict_insertion_ordered.find("") != sm_is_dict_insertion_ordered.end()); + const auto interpid = GetCurrentPyInterpreterID(); + const auto &namespaces = sm_dict_insertion_ordered_namespaces; + return (namespaces.find({interpid, registry_namespace}) != namespaces.end()) || + (inherit_global_namespace && namespaces.find({interpid, ""}) != namespaces.end()); } // Set the namespace to preserve the insertion order of the dictionary keys during flattening. static inline Py_ALWAYS_INLINE void SetDictInsertionOrdered( const bool &mode, const std::string ®istry_namespace) { - const scoped_write_lock lock{sm_is_dict_insertion_ordered_mutex}; + const scoped_write_lock lock{sm_dict_order_mutex}; + const auto interpid = GetCurrentPyInterpreterID(); + const auto key = std::make_pair(interpid, registry_namespace); if (mode) [[likely]] { - sm_is_dict_insertion_ordered.insert(registry_namespace); + sm_dict_insertion_ordered_namespaces.insert(key); } else [[unlikely]] { - sm_is_dict_insertion_ordered.erase(registry_namespace); + sm_dict_insertion_ordered_namespaces.erase(key); } } friend void BuildModule(py::module_ &mod); // NOLINT[runtime/references] + friend class PyTreeTypeRegistry; private: using RegistrationPtr = PyTreeTypeRegistry::RegistrationPtr; @@ -426,8 +429,9 @@ class PyTreeSpec { // A set of namespaces that preserve the insertion order of the dictionary keys during // flattening. 
- static inline std::unordered_set sm_is_dict_insertion_ordered{}; - static inline read_write_mutex sm_is_dict_insertion_ordered_mutex{}; + static inline std::unordered_set> + sm_dict_insertion_ordered_namespaces{}; + static inline read_write_mutex sm_dict_order_mutex{}; }; class PyTreeIter { diff --git a/src/registry.cpp b/src/registry.cpp index f82526bb..9ac51a80 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -341,6 +341,27 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( "The current interpreter ID should be present in the alive interpreters set."); sm_alive_interpids.erase(interpid); + { + const scoped_write_lock namespace_lock{PyTreeSpec::sm_dict_order_mutex}; + using key_type = decltype(PyTreeSpec::sm_dict_insertion_ordered_namespaces)::key_type; + auto &dict_insertion_ordered_namespaces = PyTreeSpec::sm_dict_insertion_ordered_namespaces; + auto entries = reserved_vector(4); + for (const auto &entry : dict_insertion_ordered_namespaces) { + if (entry.first == interpid) [[likely]] { + entries.emplace_back(entry); + } + } + for (const auto &entry : entries) { + dict_insertion_ordered_namespaces.erase(entry); + } + if (sm_alive_interpids.empty()) [[likely]] { + EXPECT_TRUE( + dict_insertion_ordered_namespaces.empty(), + "The dict insertion ordered namespaces map should be empty when there is no " + "alive Python interpreter."); + } + } + auto ®istry1 = GetSingleton(); auto ®istry2 = GetSingleton(); diff --git a/src/treespec/flatten.cpp b/src/treespec/flatten.cpp index 8c630224..2da89583 100644 --- a/src/treespec/flatten.cpp +++ b/src/treespec/flatten.cpp @@ -208,7 +208,7 @@ bool PyTreeSpec::FlattenInto(const py::handle &handle, bool is_dict_insertion_ordered_in_current_namespace = false; { #if defined(HAVE_READ_WRITE_LOCK) - const scoped_read_lock lock{sm_is_dict_insertion_ordered_mutex}; + const scoped_read_lock lock{sm_dict_order_mutex}; #endif is_dict_insertion_ordered = IsDictInsertionOrdered(registry_namespace); is_dict_insertion_ordered_in_current_namespace = @@ -484,7 +484,7 @@ bool PyTreeSpec::FlattenIntoWithPath(const py::handle &handle, bool is_dict_insertion_ordered_in_current_namespace = false; { #if defined(HAVE_READ_WRITE_LOCK) - const scoped_read_lock lock{sm_is_dict_insertion_ordered_mutex}; + const scoped_read_lock lock{sm_dict_order_mutex}; #endif is_dict_insertion_ordered = IsDictInsertionOrdered(registry_namespace); is_dict_insertion_ordered_in_current_namespace = diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index 41c9dc07..237b2f35 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -27,7 +27,6 @@ PYPY, WASM, Py_DEBUG, - Py_GIL_DISABLED, check_script_in_subprocess, ) @@ -48,9 +47,9 @@ from concurrent.futures import InterpreterPoolExecutor, as_completed -if Py_GIL_DISABLED and not Py_DEBUG: - NUM_WORKERS = 32 - NUM_FUTURES = 128 +if not Py_DEBUG: + NUM_WORKERS = 8 + NUM_FUTURES = 32 NUM_FLAKY_RERUNS = 16 else: NUM_WORKERS = 4 diff --git a/tests/concurrent/test_threading.py b/tests/concurrent/test_threading.py index d1888b42..ec172926 100644 --- a/tests/concurrent/test_threading.py +++ b/tests/concurrent/test_threading.py @@ -343,31 +343,33 @@ def test_tree_iter_thread_safe( dict_should_be_sorted, dict_session_namespace, ): - counter = itertools.count() - with optree.dict_insertion_ordered( - not dict_should_be_sorted, - namespace=dict_session_namespace or GLOBAL_NAMESPACE, - ): - new_tree = optree.tree_map( - lambda x: next(counter), - tree, - 
none_is_leaf=none_is_leaf, - namespace=namespace, - ) - num_leaves = next(counter) - assert optree.tree_leaves( - new_tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) == list(range(num_leaves)) - - it = optree.tree_iter( - new_tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - + def get_iterator(): + counter = itertools.count() + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + new_tree = optree.tree_map( + lambda x: next(counter), + tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + num_leaves = next(counter) + it = optree.tree_iter( + new_tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + return it, new_tree, num_leaves + + it, _, num_leaves = get_iterator() + sentinel = object() + assert list(it) == list(range(num_leaves)) + assert next(it, sentinel) is sentinel + + it, new_tree, _ = get_iterator() results = concurrent_run(list, it) assert sorted(itertools.chain.from_iterable(results)) == list(range(num_leaves)) for seq in results: - assert sorted(seq) == seq + assert seq == sorted(seq), f'Expected {sorted(seq)}, but got {seq}: tree {new_tree!r}.' From 2e5e53f53ef95a5dc436f3360283785a92a20eec Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sun, 28 Dec 2025 19:27:11 +0800 Subject: [PATCH 49/59] fix: fix concurrency issue --- include/optree/treespec.h | 2 -- src/treespec/traversal.cpp | 21 ++++++++++++++------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/include/optree/treespec.h b/include/optree/treespec.h index 095354c2..0261a6da 100644 --- a/include/optree/treespec.h +++ b/include/optree/treespec.h @@ -468,9 +468,7 @@ class PyTreeIter { const bool m_none_is_leaf; const std::string m_namespace; const bool m_is_dict_insertion_ordered; -#if defined(Py_GIL_DISABLED) mutable mutex m_mutex{}; -#endif template [[nodiscard]] py::object NextImpl(); diff --git a/src/treespec/traversal.cpp b/src/treespec/traversal.cpp index e22ba47e..bc379068 100644 --- a/src/treespec/traversal.cpp +++ b/src/treespec/traversal.cpp @@ -162,14 +162,21 @@ py::object PyTreeIter::NextImpl() { } py::object PyTreeIter::Next() { -#if defined(Py_GIL_DISABLED) - const scoped_lock lock{m_mutex}; +#if !defined(Py_GIL_DISABLED) + const py::gil_scoped_release gil_release; #endif - - if (m_none_is_leaf) [[unlikely]] { - return NextImpl(); - } else [[likely]] { - return NextImpl(); + { + const scoped_lock lock{m_mutex}; + { +#if !defined(Py_GIL_DISABLED) + const py::gil_scoped_acquire gil_acquire; +#endif + if (m_none_is_leaf) [[unlikely]] { + return NextImpl(); + } else [[likely]] { + return NextImpl(); + } + } } } From e39789d63462cec0d71e5dc13fe0e1df1a22854c Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sun, 28 Dec 2025 20:51:38 +0800 Subject: [PATCH 50/59] chore: add `[[likely]]` attribute --- src/registry.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/registry.cpp b/src/registry.cpp index 9ac51a80..1061b3ce 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -41,7 +41,7 @@ template "PyTree type " + PyRepr(cls) + " is already registered in the built-in types set."); cls.inc_ref(); - if (!NoneIsLeaf || kind != PyTreeKind::None) { + if (!NoneIsLeaf || kind != PyTreeKind::None) [[likely]] { auto registration = std::make_shared>(); registration->kind = kind; From 389e17c2cedd2cec3884b393169c9dd6f12cdda5 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sun, 28 Dec 2025 23:26:17 +0800 Subject: [PATCH 51/59] fix: fix PyPy --- include/optree/pymacros.h | 
11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h index 41707ff3..f2fb251f 100644 --- a/include/optree/pymacros.h +++ b/include/optree/pymacros.h @@ -73,6 +73,9 @@ inline constexpr Py_ALWAYS_INLINE bool Py_IsConstant(PyObject *x) noexcept { using interpid_t = decltype(PyInterpreterState_GetID(nullptr)); +#if defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \ + NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) + [[nodiscard]] inline bool IsCurrentPyInterpreterMain() { return PyInterpreterState_Get() == PyInterpreterState_Main(); } @@ -106,3 +109,11 @@ using interpid_t = decltype(PyInterpreterState_GetID(nullptr)); } return interpid; } + +#else + +[[nodiscard]] inline bool IsCurrentPyInterpreterMain() noexcept { return true; } +[[nodiscard]] inline interpid_t GetCurrentPyInterpreterID() noexcept { return 0; } +[[nodiscard]] inline interpid_t GetMainPyInterpreterID() noexcept { return 0; } + +#endif From 302db6a716ec9d89339f572a5c2d3f6dc2133c2c Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 29 Dec 2025 00:19:58 +0800 Subject: [PATCH 52/59] chore: use simple GIL --- include/optree/pytypes.h | 10 ++++++++++ src/treespec/traversal.cpp | 20 +++++++------------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/include/optree/pytypes.h b/include/optree/pytypes.h index a932bd9e..5ecdf3e8 100644 --- a/include/optree/pytypes.h +++ b/include/optree/pytypes.h @@ -261,6 +261,7 @@ inline bool IsNamedTupleClass(const py::handle &type) { static read_write_mutex mutex{}; { + const py::gil_scoped_release_simple gil_release{}; const scoped_read_lock lock{mutex}; const auto it = cache.find(type); if (it != cache.end()) [[likely]] { @@ -270,8 +271,10 @@ inline bool IsNamedTupleClass(const py::handle &type) { const bool result = EVALUATE_WITH_LOCK_HELD(IsNamedTupleClassImpl(type), type); { + const py::gil_scoped_release_simple gil_release{}; const scoped_write_lock lock{mutex}; if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] { + const py::gil_scoped_acquire_simple gil_acquire{}; cache.emplace(type, result); (void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void { const scoped_write_lock lock{mutex}; @@ -363,6 +366,7 @@ inline bool IsStructSequenceClass(const py::handle &type) { static read_write_mutex mutex{}; { + const py::gil_scoped_release_simple gil_release{}; const scoped_read_lock lock{mutex}; const auto it = cache.find(type); if (it != cache.end()) [[likely]] { @@ -372,8 +376,10 @@ inline bool IsStructSequenceClass(const py::handle &type) { const bool result = EVALUATE_WITH_LOCK_HELD(IsStructSequenceClassImpl(type), type); { + const py::gil_scoped_release_simple gil_release{}; const scoped_write_lock lock{mutex}; if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] { + const py::gil_scoped_acquire_simple gil_acquire{}; cache.emplace(type, result); (void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void { const scoped_write_lock lock{mutex}; @@ -446,17 +452,21 @@ inline py::tuple StructSequenceGetFields(const py::handle &object) { static read_write_mutex mutex{}; { + const py::gil_scoped_release_simple gil_release{}; const scoped_read_lock lock{mutex}; const auto it = cache.find(type); if (it != cache.end()) [[likely]] { + const py::gil_scoped_acquire_simple gil_acquire{}; return py::reinterpret_borrow(it->second); } } const py::tuple fields = EVALUATE_WITH_LOCK_HELD(StructSequenceGetFieldsImpl(type), type); { + const py::gil_scoped_release_simple gil_release{}; const 
scoped_write_lock lock{mutex}; if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] { + const py::gil_scoped_acquire_simple gil_acquire{}; cache.emplace(type, fields); fields.inc_ref(); (void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void { diff --git a/src/treespec/traversal.cpp b/src/treespec/traversal.cpp index bc379068..f1fd9482 100644 --- a/src/treespec/traversal.cpp +++ b/src/treespec/traversal.cpp @@ -162,20 +162,14 @@ py::object PyTreeIter::NextImpl() { } py::object PyTreeIter::Next() { -#if !defined(Py_GIL_DISABLED) - const py::gil_scoped_release gil_release; -#endif + const py::gil_scoped_release_simple gil_release{}; + const scoped_lock lock{m_mutex}; { - const scoped_lock lock{m_mutex}; - { -#if !defined(Py_GIL_DISABLED) - const py::gil_scoped_acquire gil_acquire; -#endif - if (m_none_is_leaf) [[unlikely]] { - return NextImpl(); - } else [[likely]] { - return NextImpl(); - } + const py::gil_scoped_acquire_simple gil_acquire{}; + if (m_none_is_leaf) [[unlikely]] { + return NextImpl(); + } else [[likely]] { + return NextImpl(); } } } From 7084688a57423a6a4eff1774375777bbd86f6a31 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 29 Dec 2025 00:34:06 +0800 Subject: [PATCH 53/59] test: update test --- tests/concurrent/test_subinterpreters.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index 237b2f35..0af8cfdc 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -83,6 +83,7 @@ def check_module_importable(): import collections import time + import optree import optree._C is_current_interpreter_main = optree._C.is_current_interpreter_main() From 1deda32481515b6bbbf9260cccf22f7dd2d568f3 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 29 Dec 2025 09:58:02 +0800 Subject: [PATCH 54/59] test: update test timeout --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index edf87196..fc731376 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -70,7 +70,7 @@ jobs: - "3.14t" - "pypy-3.11" fail-fast: false - timeout-minutes: 90 + timeout-minutes: 120 steps: - name: Checkout uses: actions/checkout@v6 From 37f22d95a7336f88124b068f595aabe9c5a09975 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 29 Dec 2025 22:58:13 +0800 Subject: [PATCH 55/59] chore: update macros --- include/optree/pytypes.h | 20 ++++++++++++++++++++ src/treespec/traversal.cpp | 4 ++++ 2 files changed, 24 insertions(+) diff --git a/include/optree/pytypes.h b/include/optree/pytypes.h index 5ecdf3e8..2dfea5f9 100644 --- a/include/optree/pytypes.h +++ b/include/optree/pytypes.h @@ -261,7 +261,9 @@ inline bool IsNamedTupleClass(const py::handle &type) { static read_write_mutex mutex{}; { +#if !defined(Py_GIL_DISABLED) const py::gil_scoped_release_simple gil_release{}; +#endif const scoped_read_lock lock{mutex}; const auto it = cache.find(type); if (it != cache.end()) [[likely]] { @@ -271,10 +273,14 @@ inline bool IsNamedTupleClass(const py::handle &type) { const bool result = EVALUATE_WITH_LOCK_HELD(IsNamedTupleClassImpl(type), type); { +#if !defined(Py_GIL_DISABLED) const py::gil_scoped_release_simple gil_release{}; +#endif const scoped_write_lock lock{mutex}; if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] { +#if !defined(Py_GIL_DISABLED) const py::gil_scoped_acquire_simple gil_acquire{}; +#endif cache.emplace(type, result); 
(void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void { const scoped_write_lock lock{mutex}; @@ -366,7 +372,9 @@ inline bool IsStructSequenceClass(const py::handle &type) { static read_write_mutex mutex{}; { +#if !defined(Py_GIL_DISABLED) const py::gil_scoped_release_simple gil_release{}; +#endif const scoped_read_lock lock{mutex}; const auto it = cache.find(type); if (it != cache.end()) [[likely]] { @@ -376,10 +384,14 @@ inline bool IsStructSequenceClass(const py::handle &type) { const bool result = EVALUATE_WITH_LOCK_HELD(IsStructSequenceClassImpl(type), type); { +#if !defined(Py_GIL_DISABLED) const py::gil_scoped_release_simple gil_release{}; +#endif const scoped_write_lock lock{mutex}; if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] { +#if !defined(Py_GIL_DISABLED) const py::gil_scoped_acquire_simple gil_acquire{}; +#endif cache.emplace(type, result); (void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void { const scoped_write_lock lock{mutex}; @@ -452,21 +464,29 @@ inline py::tuple StructSequenceGetFields(const py::handle &object) { static read_write_mutex mutex{}; { +#if !defined(Py_GIL_DISABLED) const py::gil_scoped_release_simple gil_release{}; +#endif const scoped_read_lock lock{mutex}; const auto it = cache.find(type); if (it != cache.end()) [[likely]] { +#if !defined(Py_GIL_DISABLED) const py::gil_scoped_acquire_simple gil_acquire{}; +#endif return py::reinterpret_borrow(it->second); } } const py::tuple fields = EVALUATE_WITH_LOCK_HELD(StructSequenceGetFieldsImpl(type), type); { +#if !defined(Py_GIL_DISABLED) const py::gil_scoped_release_simple gil_release{}; +#endif const scoped_write_lock lock{mutex}; if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] { +#if !defined(Py_GIL_DISABLED) const py::gil_scoped_acquire_simple gil_acquire{}; +#endif cache.emplace(type, fields); fields.inc_ref(); (void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void { diff --git a/src/treespec/traversal.cpp b/src/treespec/traversal.cpp index f1fd9482..aea43087 100644 --- a/src/treespec/traversal.cpp +++ b/src/treespec/traversal.cpp @@ -162,10 +162,14 @@ py::object PyTreeIter::NextImpl() { } py::object PyTreeIter::Next() { +#if !defined(Py_GIL_DISABLED) const py::gil_scoped_release_simple gil_release{}; +#endif const scoped_lock lock{m_mutex}; { +#if !defined(Py_GIL_DISABLED) const py::gil_scoped_acquire_simple gil_acquire{}; +#endif if (m_none_is_leaf) [[unlikely]] { return NextImpl(); } else [[likely]] { From a708f32803eaa1ea17b1638fda919a6cb44e57c3 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 31 Dec 2025 15:41:58 +0800 Subject: [PATCH 56/59] refactor: move dict order registry to `PyTreeTypeRegistry` --- include/optree/registry.h | 36 +++++++++++++++++++++++- include/optree/treespec.h | 52 ++++++----------------------------- src/optree.cpp | 4 +-- src/registry.cpp | 11 ++++---- src/treespec/constructors.cpp | 3 +- src/treespec/flatten.cpp | 14 ++++++---- 6 files changed, 61 insertions(+), 59 deletions(-) diff --git a/include/optree/registry.h b/include/optree/registry.h index 841fb8ca..fc3cc624 100644 --- a/include/optree/registry.h +++ b/include/optree/registry.h @@ -23,7 +23,7 @@ limitations under the License. 
#include // std::string #include // std::unordered_map #include // std::unordered_set -#include // std::pair +#include // std::pair, std::make_pair #include @@ -161,6 +161,33 @@ class PyTreeTypeRegistry { return sm_alive_interpids; } + // Check if should preserve the insertion order of the dictionary keys during flattening. + [[nodiscard]] static inline Py_ALWAYS_INLINE bool IsDictInsertionOrdered( + const std::string ®istry_namespace, + const bool &inherit_global_namespace = true) { + const scoped_read_lock lock{sm_dict_order_mutex}; + + const auto interpid = GetCurrentPyInterpreterID(); + const auto &namespaces = sm_dict_insertion_ordered_namespaces; + return (namespaces.find({interpid, registry_namespace}) != namespaces.end()) || + (inherit_global_namespace && namespaces.find({interpid, ""}) != namespaces.end()); + } + + // Set the namespace to preserve the insertion order of the dictionary keys during flattening. + static inline Py_ALWAYS_INLINE void SetDictInsertionOrdered( + const bool &mode, + const std::string ®istry_namespace) { + const scoped_write_lock lock{sm_dict_order_mutex}; + + const auto interpid = GetCurrentPyInterpreterID(); + const auto key = std::make_pair(interpid, registry_namespace); + if (mode) [[likely]] { + sm_dict_insertion_ordered_namespaces.insert(key); + } else [[unlikely]] { + sm_dict_insertion_ordered_namespaces.erase(key); + } + } + friend void BuildModule(py::module_ &mod); // NOLINT[runtime/references] private: @@ -193,6 +220,13 @@ class PyTreeTypeRegistry { NamedRegistrationsMap m_named_registrations{}; BuiltinsTypesSet m_builtins_types{}; + // A set of namespaces that preserve the insertion order of the dictionary keys during + // flattening. + static inline std::unordered_set> + sm_dict_insertion_ordered_namespaces{}; + static inline read_write_mutex sm_dict_order_mutex{}; + friend class PyTreeSpec; + static inline std::unordered_set sm_alive_interpids{}; static inline read_write_mutex sm_mutex{}; static inline ssize_t sm_num_interpreters_seen = 0; diff --git a/include/optree/treespec.h b/include/optree/treespec.h index 0261a6da..c73684f6 100644 --- a/include/optree/treespec.h +++ b/include/optree/treespec.h @@ -17,14 +17,13 @@ limitations under the License. #pragma once -#include // std::unique_ptr -#include // std::optional, std::nullopt -#include // std::string -#include // std::thread::id -#include // std::tuple -#include // std::unordered_set -#include // std::pair, std::make_pair -#include // std::vector +#include // std::unique_ptr +#include // std::optional, std::nullopt +#include // std::string +#include // std::thread::id +#include // std::tuple +#include // std::pair +#include // std::vector #include @@ -259,35 +258,7 @@ class PyTreeSpec { const bool &none_is_leaf = false, const std::string ®istry_namespace = ""); - // Check if should preserve the insertion order of the dictionary keys during flattening. - [[nodiscard]] static inline Py_ALWAYS_INLINE bool IsDictInsertionOrdered( - const std::string ®istry_namespace, - const bool &inherit_global_namespace = true) { - const scoped_read_lock lock{sm_dict_order_mutex}; - - const auto interpid = GetCurrentPyInterpreterID(); - const auto &namespaces = sm_dict_insertion_ordered_namespaces; - return (namespaces.find({interpid, registry_namespace}) != namespaces.end()) || - (inherit_global_namespace && namespaces.find({interpid, ""}) != namespaces.end()); - } - - // Set the namespace to preserve the insertion order of the dictionary keys during flattening. 
- static inline Py_ALWAYS_INLINE void SetDictInsertionOrdered( - const bool &mode, - const std::string ®istry_namespace) { - const scoped_write_lock lock{sm_dict_order_mutex}; - - const auto interpid = GetCurrentPyInterpreterID(); - const auto key = std::make_pair(interpid, registry_namespace); - if (mode) [[likely]] { - sm_dict_insertion_ordered_namespaces.insert(key); - } else [[unlikely]] { - sm_dict_insertion_ordered_namespaces.erase(key); - } - } - friend void BuildModule(py::module_ &mod); // NOLINT[runtime/references] - friend class PyTreeTypeRegistry; private: using RegistrationPtr = PyTreeTypeRegistry::RegistrationPtr; @@ -426,12 +397,6 @@ class PyTreeSpec { // Used in tp_clear for GC support. static int PyTpClear(PyObject *self_base); - - // A set of namespaces that preserve the insertion order of the dictionary keys during - // flattening. - static inline std::unordered_set> - sm_dict_insertion_ordered_namespaces{}; - static inline read_write_mutex sm_dict_order_mutex{}; }; class PyTreeIter { @@ -445,7 +410,8 @@ class PyTreeIter { m_leaf_predicate{leaf_predicate}, m_none_is_leaf{none_is_leaf}, m_namespace{registry_namespace}, - m_is_dict_insertion_ordered{PyTreeSpec::IsDictInsertionOrdered(registry_namespace)} {} + m_is_dict_insertion_ordered{ + PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace)} {} PyTreeIter() = delete; ~PyTreeIter() = default; diff --git a/src/optree.cpp b/src/optree.cpp index d9ca8e17..9a3ae3a4 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -146,12 +146,12 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] py::pos_only(), py::arg("namespace") = "") .def("is_dict_insertion_ordered", - &PyTreeSpec::IsDictInsertionOrdered, + &PyTreeTypeRegistry::IsDictInsertionOrdered, "Return whether need to preserve the dict insertion order during flattening.", py::arg("namespace") = "", py::arg("inherit_global_namespace") = true) .def("set_dict_insertion_ordered", - &PyTreeSpec::SetDictInsertionOrdered, + &PyTreeTypeRegistry::SetDictInsertionOrdered, "Set whether need to preserve the dict insertion order during flattening.", py::arg("mode"), py::pos_only(), diff --git a/src/registry.cpp b/src/registry.cpp index 1061b3ce..079312bd 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -342,21 +342,20 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( sm_alive_interpids.erase(interpid); { - const scoped_write_lock namespace_lock{PyTreeSpec::sm_dict_order_mutex}; - using key_type = decltype(PyTreeSpec::sm_dict_insertion_ordered_namespaces)::key_type; - auto &dict_insertion_ordered_namespaces = PyTreeSpec::sm_dict_insertion_ordered_namespaces; + const scoped_write_lock namespace_lock{sm_dict_order_mutex}; + using key_type = decltype(sm_dict_insertion_ordered_namespaces)::key_type; auto entries = reserved_vector(4); - for (const auto &entry : dict_insertion_ordered_namespaces) { + for (const auto &entry : sm_dict_insertion_ordered_namespaces) { if (entry.first == interpid) [[likely]] { entries.emplace_back(entry); } } for (const auto &entry : entries) { - dict_insertion_ordered_namespaces.erase(entry); + sm_dict_insertion_ordered_namespaces.erase(entry); } if (sm_alive_interpids.empty()) [[likely]] { EXPECT_TRUE( - dict_insertion_ordered_namespaces.empty(), + sm_dict_insertion_ordered_namespaces.empty(), "The dict insertion ordered namespaces map should be empty when there is no " "alive Python interpreter."); } diff --git a/src/treespec/constructors.cpp b/src/treespec/constructors.cpp index cc4ec090..49c861ca 100644 --- 
a/src/treespec/constructors.cpp +++ b/src/treespec/constructors.cpp @@ -170,7 +170,8 @@ template keys = DictKeys(dict); if (node.kind != PyTreeKind::OrderedDict) [[likely]] { node.original_keys = py::getattr(keys, "copy")(); - if (!IsDictInsertionOrdered(registry_namespace)) [[likely]] { + if (!PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace)) + [[likely]] { TotalOrderSort(keys); } } diff --git a/src/treespec/flatten.cpp b/src/treespec/flatten.cpp index 2da89583..44aa6677 100644 --- a/src/treespec/flatten.cpp +++ b/src/treespec/flatten.cpp @@ -208,11 +208,12 @@ bool PyTreeSpec::FlattenInto(const py::handle &handle, bool is_dict_insertion_ordered_in_current_namespace = false; { #if defined(HAVE_READ_WRITE_LOCK) - const scoped_read_lock lock{sm_dict_order_mutex}; + const scoped_read_lock lock{PyTreeTypeRegistry::sm_dict_order_mutex}; #endif - is_dict_insertion_ordered = IsDictInsertionOrdered(registry_namespace); + is_dict_insertion_ordered = PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace); is_dict_insertion_ordered_in_current_namespace = - IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false); + PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace, + /*inherit_global_namespace=*/false); } if (none_is_leaf) [[unlikely]] { @@ -484,11 +485,12 @@ bool PyTreeSpec::FlattenIntoWithPath(const py::handle &handle, bool is_dict_insertion_ordered_in_current_namespace = false; { #if defined(HAVE_READ_WRITE_LOCK) - const scoped_read_lock lock{sm_dict_order_mutex}; + const scoped_read_lock lock{PyTreeTypeRegistry::sm_dict_order_mutex}; #endif - is_dict_insertion_ordered = IsDictInsertionOrdered(registry_namespace); + is_dict_insertion_ordered = PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace); is_dict_insertion_ordered_in_current_namespace = - IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false); + PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace, + /*inherit_global_namespace=*/false); } auto stack = reserved_vector(4); From 99cc0b1465dac81692fc7240062eacf4c55e0999 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 1 Jan 2026 00:51:21 +0800 Subject: [PATCH 57/59] chore: handle refcount --- src/registry.cpp | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/registry.cpp b/src/registry.cpp index 079312bd..4e1a97ff 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -40,7 +40,6 @@ template EXPECT_TRUE(registry.m_builtins_types.emplace(cls).second, "PyTree type " + PyRepr(cls) + " is already registered in the built-in types set."); - cls.inc_ref(); if (!NoneIsLeaf || kind != PyTreeKind::None) [[likely]] { auto registration = std::make_shared>(); @@ -50,9 +49,9 @@ template registry.m_registrations.emplace(cls, std::move(registration)).second, "PyTree type " + PyRepr(cls) + " is already registered in the global namespace."); - if constexpr (!NoneIsLeaf) { - cls.inc_ref(); - } + } + if constexpr (!NoneIsLeaf) { + cls.inc_ref(); } }; add_builtin_type(PyNoneTypeObject, PyTreeKind::None); @@ -343,8 +342,7 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( { const scoped_write_lock namespace_lock{sm_dict_order_mutex}; - using key_type = decltype(sm_dict_insertion_ordered_namespaces)::key_type; - auto entries = reserved_vector(4); + auto entries = reserved_vector(4); for (const auto &entry : sm_dict_insertion_ordered_namespaces) { if (entry.first == interpid) [[likely]] { entries.emplace_back(entry); @@ -403,9 +401,6 @@ template PyTreeKind 
PyTreeTypeRegistry::GetKind( } #endif - for (const auto &cls : registry1.m_builtins_types) { - cls.dec_ref(); - } for (const auto &[_, registration1] : registry1.m_registrations) { registration1->type.dec_ref(); registration1->flatten_func.dec_ref(); From 60b5afcb89bcb8587654944602a7ff1d7a24fe66 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 1 Jan 2026 01:03:52 +0800 Subject: [PATCH 58/59] docs(CHANGELOG): update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7be653bd..27ff93e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Add subinterpreters support for Python 3.14+ by [@XuehaiPan](https://github.com/XuehaiPan) in [#245](https://github.com/metaopt/optree/pull/245). ### Changed From 76ffa627223a837027567f8a091fb75cf0b0e7cd Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 7 Jan 2026 20:08:09 +0800 Subject: [PATCH 59/59] chore: trigger CI
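
The "Add subinterpreters support for Python 3.14+" CHANGELOG entry above is exercised by tests/concurrent/test_subinterpreters.py through concurrent.futures.InterpreterPoolExecutor, with optree imported inside the worker (see the import added in patch 53). The snippet below is only a rough sketch of that usage pattern, not code from the patches: it assumes Python 3.14+, an optree build with subinterpreter support, and a worker function that lives in an importable module; the function names, the sample tree, and the worker counts are illustrative.

# Rough sketch in the spirit of tests/concurrent/test_subinterpreters.py;
# names and counts are illustrative, not the actual test code.
from concurrent.futures import InterpreterPoolExecutor, as_completed


def flatten_roundtrip_ok() -> bool:
    # Imported inside the worker so the extension module is initialized
    # separately in each subinterpreter.
    import optree

    tree = {'b': (2, [3, 4]), 'a': 1, 'c': None}
    leaves, treespec = optree.tree_flatten(tree)
    return optree.tree_unflatten(treespec, leaves) == tree


def test_flatten_roundtrip_in_subinterpreters():
    # Each worker runs in its own subinterpreter (Python 3.14+); submitted
    # callables and return values are transferred between interpreters, so
    # the worker stays at module level and returns plain data.
    with InterpreterPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(flatten_roundtrip_ok) for _ in range(16)]
        assert all(future.result() for future in as_completed(futures))

The actual tests additionally guard against PyPy and WASM builds and scale NUM_WORKERS/NUM_FUTURES down under Py_DEBUG, as shown in the diff above.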
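
Patches 48 and 56 rework the bookkeeping behind optree.dict_insertion_ordered, the context manager used by test_tree_iter_thread_safe in tests/concurrent/test_threading.py, keying it by (interpreter ID, namespace) and moving it into PyTreeTypeRegistry. A minimal sketch of the Python-side behavior it controls, assuming the default of sorting dict keys during flattening; the 'demo' namespace and the concrete leaf orders are illustrative.

import optree

tree = {'b': 2, 'a': 1, 'c': 3}

# Default: dict keys are sorted while flattening (TotalOrderSort above),
# so leaves come out in key order 'a', 'b', 'c'.
leaves, _ = optree.tree_flatten(tree)
assert leaves == [1, 2, 3]

# Inside the context manager, insertion order is preserved for trees
# flattened under the same namespace, giving key order 'b', 'a', 'c'.
with optree.dict_insertion_ordered(True, namespace='demo'):
    leaves, _ = optree.tree_flatten(tree, namespace='demo')
assert leaves == [2, 1, 3]

With the per-interpreter keying introduced in patch 48, enabling the flag in one subinterpreter no longer leaks into others, and the registry.cpp cleanup above drops an interpreter's entries when it shuts down.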