Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
edc21c7
chore(pre-commit): update pre-commit hooks
XuehaiPan Nov 29, 2025
d5a1cc4
test: reorgainize test files
XuehaiPan Oct 8, 2025
8bd5fde
test: add subinterpreters tests
XuehaiPan Oct 8, 2025
23abd08
chore: reorder dependencies
XuehaiPan Nov 29, 2025
65a8fce
test: add import tests
XuehaiPan Nov 29, 2025
261729a
chore(workflows): update find options
XuehaiPan Nov 29, 2025
6b9b38a
feat: set `py::multiple_interpreters::shared_gil()`
XuehaiPan Nov 30, 2025
e4f2019
feat: set `py::multiple_interpreters::per_interpreter_gil()`
XuehaiPan Nov 30, 2025
eec4602
chore: disable subinterpreters for pre-3.14
XuehaiPan Dec 1, 2025
fade81a
chore(pre-commit): [pre-commit.ci] autoupdate
pre-commit-ci[bot] Dec 1, 2025
e7a2a89
refactor: simplify per-interpreter registry
XuehaiPan Dec 1, 2025
1fd07c8
fix: fix collections type for subinterpreters
XuehaiPan Dec 1, 2025
209801d
style: add `[[nodiscard]]` attributes
XuehaiPan Dec 2, 2025
02648ab
chore(pre-commit): update pre-commit hooks
XuehaiPan Dec 6, 2025
a32c672
chore: sync changes
XuehaiPan Dec 6, 2025
c026062
chore: sync changes
XuehaiPan Dec 6, 2025
3ba5f29
chore: sync changes
XuehaiPan Dec 6, 2025
278d920
Merge branch 'pre-commit-ci-update-config' into subinterpreters
XuehaiPan Dec 6, 2025
51f6878
Merge branch 'main' into subinterpreters
XuehaiPan Dec 8, 2025
e276ec3
chore: set macro `OPTREE_HAS_SUBINTERPRETER_SUPPORT`
XuehaiPan Dec 8, 2025
86aa6f6
fix: fix find command
XuehaiPan Dec 8, 2025
f955df0
refactor: change pointers to references
XuehaiPan Dec 8, 2025
e2463e2
test: add import tests
XuehaiPan Dec 8, 2025
be1fb8e
fix: do not use `py::gil_safe_call_once_and_store` for subinterpreters
XuehaiPan Dec 8, 2025
8d5c8bf
Merge branch 'main' into subinterpreters
XuehaiPan Dec 8, 2025
3495e72
chore: reorgainize code
XuehaiPan Dec 8, 2025
2278070
feat: add function `get_main_interpreter_id()`
XuehaiPan Dec 8, 2025
8c36df2
chore: reorgainize code
XuehaiPan Dec 10, 2025
b7d9173
fix: fix docs dependency resolving
XuehaiPan Dec 10, 2025
d25f6dc
fix(workflows/lint): fix docs dependency resolving
XuehaiPan Dec 10, 2025
af30820
Merge branch 'main' into subinterpreters
XuehaiPan Dec 10, 2025
0801449
fix(workflows/lint): fix docs dependency resolving
XuehaiPan Dec 10, 2025
8ac7a96
Merge branch 'main' into subinterpreters
XuehaiPan Dec 10, 2025
e119a2f
chore: update nightly pybind11 url
XuehaiPan Dec 14, 2025
5dcc2f8
feat: improve sanity check error messages
XuehaiPan Dec 14, 2025
1c43acc
revert
XuehaiPan Dec 14, 2025
90afec2
update
XuehaiPan Dec 14, 2025
ce661a2
chore: split ci jobs
XuehaiPan Dec 15, 2025
afae64c
fix: fix repr for exception
XuehaiPan Dec 15, 2025
5c3ba09
Merge branch 'main' into subinterpreters
XuehaiPan Dec 17, 2025
5fb4769
test: skip failed tests
XuehaiPan Dec 17, 2025
e486981
test: set no-cov for subinterpreter tests
XuehaiPan Dec 17, 2025
d89e120
test: set env for subprocess
XuehaiPan Dec 18, 2025
5efad7e
chore: split tests
XuehaiPan Dec 18, 2025
bb0a4d9
chore(pre-commit): update pre-commit hooks
XuehaiPan Dec 19, 2025
b7cb1cf
test: enable subinterpreter tests
XuehaiPan Dec 19, 2025
3fb200e
Merge remote-tracking branch 'upstream/main' into subinterpreters
XuehaiPan Dec 21, 2025
f7abc85
chore: update nightly remote
XuehaiPan Dec 25, 2025
ea8cc2c
chore: add more build time meta
XuehaiPan Dec 25, 2025
2935452
chore: update nightly remote
XuehaiPan Dec 26, 2025
fb08dec
chore: update test
XuehaiPan Dec 27, 2025
0d18d19
chore: update test
XuehaiPan Dec 27, 2025
3996bb0
chore: update macros
XuehaiPan Dec 27, 2025
c420862
chore: remove `Py_Get_ID`
XuehaiPan Dec 27, 2025
56d991f
chore: cleanup dict order namespaces
XuehaiPan Dec 28, 2025
2e5e53f
fix: fix concurrency issue
XuehaiPan Dec 28, 2025
e39789d
chore: add `[[likely]]` attribute
XuehaiPan Dec 28, 2025
389e17c
fix: fix PyPy
XuehaiPan Dec 28, 2025
302db6a
chore: use simple GIL
XuehaiPan Dec 28, 2025
7084688
test: update test
XuehaiPan Dec 28, 2025
1deda32
test: update test timeout
XuehaiPan Dec 29, 2025
37f22d9
chore: update macros
XuehaiPan Dec 29, 2025
a708f32
refactor: move dict order registry to `PyTreeTypeRegistry`
XuehaiPan Dec 31, 2025
99cc0b1
chore: handle refcount
XuehaiPan Dec 31, 2025
60b5afc
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Dec 31, 2025
0907ca9
Merge remote-tracking branch 'upstream/main' into subinterpreters
XuehaiPan Jan 3, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .github/workflows/tests-with-pydebug.yml
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,13 @@ jobs:
"--cov-report=xml:coverage-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml"
"--junit-xml=junit-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml"
)
make test PYTESTOPTS="${PYTESTOPTS[*]}"

if ${{ env.PYTHON }} -c 'import sys, optree; sys.exit(not optree._C.OPTREE_HAS_SUBINTERPRETER_SUPPORT)'; then
make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'concurrent' --no-cov"
make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'not subinterpreter'"
else
make test PYTESTOPTS="${PYTESTOPTS[*]}"
fi

CORE_DUMP_FILES="$(
find . -type d -path "./venv" -prune \
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Add subinterpreters support for Python 3.14+ by [@XuehaiPan](https://github.com/XuehaiPan) in [#245](https://github.com/metaopt/optree/pull/245).

### Changed

Expand Down
58 changes: 58 additions & 0 deletions include/optree/pymacros.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.

#pragma once

#include <stdexcept> // std::runtime_error

#include <Python.h>

#include <pybind11/pybind11.h>
Expand All @@ -32,6 +34,15 @@ limitations under the License.
// NOLINTNEXTLINE[bugprone-macro-parentheses]
#define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0))

#if !defined(PYPY_VERSION) && (PY_VERSION_HEX >= 0x030E0000 /* Python 3.14 */) && \
(PYBIND11_VERSION_HEX >= 0x030002A0 /* pybind11 3.0.2.a0 */) && \
(defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \
NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT))
# define OPTREE_HAS_SUBINTERPRETER_SUPPORT 1
#else
# undef OPTREE_HAS_SUBINTERPRETER_SUPPORT
#endif

namespace py = pybind11;

#if !defined(Py_ALWAYS_INLINE)
Expand Down Expand Up @@ -59,3 +70,50 @@ inline constexpr Py_ALWAYS_INLINE bool Py_IsConstant(PyObject *x) noexcept {
return Py_IsNone(x) || Py_IsTrue(x) || Py_IsFalse(x);
}
#define Py_IsConstant(x) Py_IsConstant(x)

using interpid_t = decltype(PyInterpreterState_GetID(nullptr));

#if defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \
NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT)

[[nodiscard]] inline bool IsCurrentPyInterpreterMain() {
return PyInterpreterState_Get() == PyInterpreterState_Main();
}

[[nodiscard]] inline interpid_t GetCurrentPyInterpreterID() {
PyInterpreterState *interp = PyInterpreterState_Get();
if (PyErr_Occurred() != nullptr) [[unlikely]] {
throw py::error_already_set();
}
if (interp == nullptr) [[unlikely]] {
throw std::runtime_error("Failed to get the current Python interpreter state.");
}
const interpid_t interpid = PyInterpreterState_GetID(interp);
if (PyErr_Occurred() != nullptr) [[unlikely]] {
throw py::error_already_set();
}
return interpid;
}

[[nodiscard]] inline interpid_t GetMainPyInterpreterID() {
PyInterpreterState *interp = PyInterpreterState_Main();
if (PyErr_Occurred() != nullptr) [[unlikely]] {
throw py::error_already_set();
}
if (interp == nullptr) [[unlikely]] {
throw std::runtime_error("Failed to get the main Python interpreter state.");
}
const interpid_t interpid = PyInterpreterState_GetID(interp);
if (PyErr_Occurred() != nullptr) [[unlikely]] {
throw py::error_already_set();
}
return interpid;
}

#else

[[nodiscard]] inline bool IsCurrentPyInterpreterMain() noexcept { return true; }
[[nodiscard]] inline interpid_t GetCurrentPyInterpreterID() noexcept { return 0; }
[[nodiscard]] inline interpid_t GetMainPyInterpreterID() noexcept { return 0; }

#endif
58 changes: 57 additions & 1 deletion include/optree/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ limitations under the License.
#include <string> // std::string
#include <unordered_map> // std::unordered_map
#include <unordered_set> // std::unordered_set
#include <utility> // std::pair
#include <utility> // std::pair, std::make_pair

#include <pybind11/pybind11.h>

#include "optree/exceptions.h"
#include "optree/hashing.h"
#include "optree/pymacros.h"
#include "optree/synchronization.h"

namespace optree {
Expand Down Expand Up @@ -141,6 +142,52 @@ class PyTreeTypeRegistry {
return count1;
}

// Get the number of alive interpreters that have seen the registry.
[[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersAlive() {
const scoped_read_lock lock{sm_mutex};
return py::ssize_t_cast(sm_alive_interpids.size());
}

// Get the number of interpreters that have seen the registry.
[[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersSeen() {
const scoped_read_lock lock{sm_mutex};
return sm_num_interpreters_seen;
}

// Get the IDs of alive interpreters that have seen the registry.
[[nodiscard]] static inline Py_ALWAYS_INLINE std::unordered_set<interpid_t>
GetAliveInterpreterIDs() {
const scoped_read_lock lock{sm_mutex};
return sm_alive_interpids;
}

// Check if should preserve the insertion order of the dictionary keys during flattening.
[[nodiscard]] static inline Py_ALWAYS_INLINE bool IsDictInsertionOrdered(
const std::string &registry_namespace,
const bool &inherit_global_namespace = true) {
const scoped_read_lock lock{sm_dict_order_mutex};

const auto interpid = GetCurrentPyInterpreterID();
const auto &namespaces = sm_dict_insertion_ordered_namespaces;
return (namespaces.find({interpid, registry_namespace}) != namespaces.end()) ||
(inherit_global_namespace && namespaces.find({interpid, ""}) != namespaces.end());
}

// Set the namespace to preserve the insertion order of the dictionary keys during flattening.
static inline Py_ALWAYS_INLINE void SetDictInsertionOrdered(
const bool &mode,
const std::string &registry_namespace) {
const scoped_write_lock lock{sm_dict_order_mutex};

const auto interpid = GetCurrentPyInterpreterID();
const auto key = std::make_pair(interpid, registry_namespace);
if (mode) [[likely]] {
sm_dict_insertion_ordered_namespaces.insert(key);
} else [[unlikely]] {
sm_dict_insertion_ordered_namespaces.erase(key);
}
}

friend void BuildModule(py::module_ &mod); // NOLINT[runtime/references]

private:
Expand Down Expand Up @@ -173,7 +220,16 @@ class PyTreeTypeRegistry {
NamedRegistrationsMap m_named_registrations{};
BuiltinsTypesSet m_builtins_types{};

// A set of namespaces that preserve the insertion order of the dictionary keys during
// flattening.
static inline std::unordered_set<std::pair<interpid_t, std::string>>
sm_dict_insertion_ordered_namespaces{};
static inline read_write_mutex sm_dict_order_mutex{};
friend class PyTreeSpec;

static inline std::unordered_set<interpid_t> sm_alive_interpids{};
static inline read_write_mutex sm_mutex{};
static inline ssize_t sm_num_interpreters_seen = 0;
};

} // namespace optree
48 changes: 9 additions & 39 deletions include/optree/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@ limitations under the License.

#pragma once

#include <memory> // std::unique_ptr
#include <optional> // std::optional, std::nullopt
#include <string> // std::string
#include <thread> // std::thread::id
#include <tuple> // std::tuple
#include <unordered_set> // std::unordered_set
#include <utility> // std::pair
#include <vector> // std::vector
#include <memory> // std::unique_ptr
#include <optional> // std::optional, std::nullopt
#include <string> // std::string
#include <thread> // std::thread::id
#include <tuple> // std::tuple
#include <utility> // std::pair
#include <vector> // std::vector

#include <pybind11/pybind11.h>

Expand Down Expand Up @@ -259,31 +258,6 @@ class PyTreeSpec {
const bool &none_is_leaf = false,
const std::string &registry_namespace = "");

// Check if should preserve the insertion order of the dictionary keys during flattening.
[[nodiscard]] static inline Py_ALWAYS_INLINE bool IsDictInsertionOrdered(
const std::string &registry_namespace,
const bool &inherit_global_namespace = true) {
const scoped_read_lock lock{sm_is_dict_insertion_ordered_mutex};

return (sm_is_dict_insertion_ordered.find(registry_namespace) !=
sm_is_dict_insertion_ordered.end()) ||
(inherit_global_namespace &&
sm_is_dict_insertion_ordered.find("") != sm_is_dict_insertion_ordered.end());
}

// Set the namespace to preserve the insertion order of the dictionary keys during flattening.
static inline Py_ALWAYS_INLINE void SetDictInsertionOrdered(
const bool &mode,
const std::string &registry_namespace) {
const scoped_write_lock lock{sm_is_dict_insertion_ordered_mutex};

if (mode) [[likely]] {
sm_is_dict_insertion_ordered.insert(registry_namespace);
} else [[unlikely]] {
sm_is_dict_insertion_ordered.erase(registry_namespace);
}
}

friend void BuildModule(py::module_ &mod); // NOLINT[runtime/references]

private:
Expand Down Expand Up @@ -423,11 +397,6 @@ class PyTreeSpec {

// Used in tp_clear for GC support.
static int PyTpClear(PyObject *self_base);

// A set of namespaces that preserve the insertion order of the dictionary keys during
// flattening.
static inline std::unordered_set<std::string> sm_is_dict_insertion_ordered{};
static inline read_write_mutex sm_is_dict_insertion_ordered_mutex{};
};

class PyTreeIter {
Expand All @@ -441,7 +410,8 @@ class PyTreeIter {
m_leaf_predicate{leaf_predicate},
m_none_is_leaf{none_is_leaf},
m_namespace{registry_namespace},
m_is_dict_insertion_ordered{PyTreeSpec::IsDictInsertionOrdered(registry_namespace)} {}
m_is_dict_insertion_ordered{
PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace)} {}

PyTreeIter() = delete;
~PyTreeIter() = default;
Expand Down
9 changes: 9 additions & 0 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,13 @@ Py_DEBUG: Final[bool]
Py_GIL_DISABLED: Final[bool]
PYBIND11_VERSION_HEX: Final[int]
PYBIND11_INTERNALS_VERSION: Final[int]
PYBIND11_INTERNALS_ID: Final[str]
PYBIND11_MODULE_LOCAL_ID: Final[str]
PYBIND11_HAS_NATIVE_ENUM: Final[bool]
PYBIND11_HAS_INTERNALS_WITH_SMART_HOLDER_SUPPORT: Final[bool]
PYBIND11_HAS_SUBINTERPRETER_SUPPORT: Final[bool]
GLIBCXX_USE_CXX11_ABI: Final[bool]
OPTREE_HAS_SUBINTERPRETER_SUPPORT: Final[bool]

@final
class InternalError(SystemError): ...
Expand Down Expand Up @@ -214,3 +217,9 @@ def set_dict_insertion_ordered(
namespace: str = '',
) -> None: ...
def get_registry_size(namespace: str | None = None) -> int: ...
def get_num_interpreters_seen() -> int: ...
def get_num_interpreters_alive() -> int: ...
def get_alive_interpreter_ids() -> set[int]: ...
def is_current_interpreter_main() -> bool: ...
def get_current_interpreter_id() -> int: ...
def get_main_interpreter_id() -> int: ...
33 changes: 31 additions & 2 deletions src/optree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references]
#endif
BUILDTIME_METADATA["PYBIND11_VERSION_HEX"] = py::int_(PYBIND11_VERSION_HEX);
BUILDTIME_METADATA["PYBIND11_INTERNALS_VERSION"] = py::int_(PYBIND11_INTERNALS_VERSION);
BUILDTIME_METADATA["PYBIND11_INTERNALS_ID"] = py::str(PYBIND11_INTERNALS_ID);
BUILDTIME_METADATA["PYBIND11_MODULE_LOCAL_ID"] = py::str(PYBIND11_MODULE_LOCAL_ID);
#if defined(PYBIND11_HAS_NATIVE_ENUM) && NONZERO_OR_EMPTY(PYBIND11_HAS_NATIVE_ENUM)
BUILDTIME_METADATA["PYBIND11_HAS_NATIVE_ENUM"] = py::bool_(true);
#else
Expand All @@ -95,6 +97,11 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references]
#else
BUILDTIME_METADATA["GLIBCXX_USE_CXX11_ABI"] = py::bool_(false);
#endif
#if defined(OPTREE_HAS_SUBINTERPRETER_SUPPORT)
BUILDTIME_METADATA["OPTREE_HAS_SUBINTERPRETER_SUPPORT"] = py::bool_(true);
#else
BUILDTIME_METADATA["OPTREE_HAS_SUBINTERPRETER_SUPPORT"] = py::bool_(false);
#endif

mod.attr("BUILDTIME_METADATA") = std::move(BUILDTIME_METADATA);
py::exec(
Expand Down Expand Up @@ -139,12 +146,12 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references]
py::pos_only(),
py::arg("namespace") = "")
.def("is_dict_insertion_ordered",
&PyTreeSpec::IsDictInsertionOrdered,
&PyTreeTypeRegistry::IsDictInsertionOrdered,
"Return whether need to preserve the dict insertion order during flattening.",
py::arg("namespace") = "",
py::arg("inherit_global_namespace") = true)
.def("set_dict_insertion_ordered",
&PyTreeSpec::SetDictInsertionOrdered,
&PyTreeTypeRegistry::SetDictInsertionOrdered,
"Set whether need to preserve the dict insertion order during flattening.",
py::arg("mode"),
py::pos_only(),
Expand All @@ -153,6 +160,24 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references]
&PyTreeTypeRegistry::GetRegistrySize,
"Get the number of registered types.",
py::arg("namespace") = std::nullopt)
.def("get_num_interpreters_seen",
&PyTreeTypeRegistry::GetNumInterpretersSeen,
"Get the number of interpreters that have seen the registry.")
.def("get_num_interpreters_alive",
&PyTreeTypeRegistry::GetNumInterpretersAlive,
"Get the number of alive interpreters that have seen the registry.")
.def("get_alive_interpreter_ids",
&PyTreeTypeRegistry::GetAliveInterpreterIDs,
"Get the IDs of alive interpreters that have seen the registry.")
.def("is_current_interpreter_main",
&IsCurrentPyInterpreterMain,
"Check whether the current interpreter is the main interpreter.")
.def("get_current_interpreter_id",
&GetCurrentPyInterpreterID,
"Get the ID of the current interpreter.")
.def("get_main_interpreter_id",
&GetMainPyInterpreterID,
"Get the ID of the main interpreter.")
.def("flatten",
&PyTreeSpec::Flatten,
"Flatten a pytree.",
Expand Down Expand Up @@ -528,7 +553,11 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references]

// NOLINTBEGIN[cppcoreguidelines-pro-bounds-pointer-arithmetic,cppcoreguidelines-pro-type-vararg]
#if PYBIND11_VERSION_HEX >= 0x020D00F0 // pybind11 2.13.0
# if defined(OPTREE_HAS_SUBINTERPRETER_SUPPORT)
PYBIND11_MODULE(_C, mod, py::mod_gil_not_used(), py::multiple_interpreters::per_interpreter_gil())
# else
PYBIND11_MODULE(_C, mod, py::mod_gil_not_used())
# endif
#else
PYBIND11_MODULE(_C, mod)
#endif
Expand Down
Loading
Loading