diff --git a/.github/workflows/set_setup_requires.py b/.github/workflows/set_setup_requires.py
index df07c011..afd817c9 100755
--- a/.github/workflows/set_setup_requires.py
+++ b/.github/workflows/set_setup_requires.py
@@ -9,7 +9,7 @@
 ROOT = Path(__file__).absolute().parents[2]
 PYPROJECT_FILE = ROOT / 'pyproject.toml'
 
-PYBIND11_GIT_URL = 'https://github.com/pybind/pybind11.git'
+PYBIND11_GIT_URL = 'https://github.com/XuehaiPan/pybind11.git'
 
 
 if __name__ == '__main__':
diff --git a/.github/workflows/tests-with-pydebug.yml b/.github/workflows/tests-with-pydebug.yml
index af9530bd..2573e64e 100644
--- a/.github/workflows/tests-with-pydebug.yml
+++ b/.github/workflows/tests-with-pydebug.yml
@@ -347,7 +347,13 @@ jobs:
             "--cov-report=xml:coverage-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml"
             "--junit-xml=junit-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml"
           )
-          make test PYTESTOPTS="${PYTESTOPTS[*]}"
+
+          if ${{ env.PYTHON }} -c 'import sys, optree; sys.exit(not optree._C.OPTREE_HAS_SUBINTERPRETER_SUPPORT)'; then
+            make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'concurrent' --no-cov"
+            make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'not subinterpreter'"
+          else
+            make test PYTESTOPTS="${PYTESTOPTS[*]}"
+          fi
 
           CORE_DUMP_FILES="$(
             find . -type d -path "./venv" -prune \
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7be653bd..27ff93e0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
--
+- Add subinterpreter support for Python 3.14+ by [@XuehaiPan](https://github.com/XuehaiPan) in [#245](https://github.com/metaopt/optree/pull/245).
 
 ### Changed
 
diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h
index 56e5cc01..f2fb251f 100644
--- a/include/optree/pymacros.h
+++ b/include/optree/pymacros.h
@@ -17,6 +17,8 @@ limitations under the License.
 
 #pragma once
 
+#include <stdexcept>  // std::runtime_error
+
 #include <Python.h>
 
 #include <pybind11/pybind11.h>
@@ -32,6 +34,15 @@ limitations under the License.
 // NOLINTNEXTLINE[bugprone-macro-parentheses]
 #define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0))
 
+#if !defined(PYPY_VERSION) && (PY_VERSION_HEX >= 0x030E0000 /* Python 3.14 */) && \
+    (PYBIND11_VERSION_HEX >= 0x030002A0 /* pybind11 3.0.2.a0 */) && \
+    (defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \
+     NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT))
+# define OPTREE_HAS_SUBINTERPRETER_SUPPORT 1
+#else
+# undef OPTREE_HAS_SUBINTERPRETER_SUPPORT
+#endif
+
 namespace py = pybind11;
 
 #if !defined(Py_ALWAYS_INLINE)
@@ -59,3 +70,50 @@ inline constexpr Py_ALWAYS_INLINE bool Py_IsConstant(PyObject *x) noexcept {
     return Py_IsNone(x) || Py_IsTrue(x) || Py_IsFalse(x);
 }
 #define Py_IsConstant(x) Py_IsConstant(x)
+
+using interpid_t = decltype(PyInterpreterState_GetID(nullptr));
+
+#if defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \
+    NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT)
+
+[[nodiscard]] inline bool IsCurrentPyInterpreterMain() {
+    return PyInterpreterState_Get() == PyInterpreterState_Main();
+}
+
+[[nodiscard]] inline interpid_t GetCurrentPyInterpreterID() {
+    PyInterpreterState *interp = PyInterpreterState_Get();
+    if (PyErr_Occurred() != nullptr) [[unlikely]] {
+        throw py::error_already_set();
+    }
+    if (interp == nullptr) [[unlikely]] {
+        throw std::runtime_error("Failed to get the current Python interpreter state.");
+    }
+    const interpid_t interpid = PyInterpreterState_GetID(interp);
+    if (PyErr_Occurred() != nullptr) [[unlikely]] {
+        throw py::error_already_set();
+    }
+    return interpid;
+}
+
+[[nodiscard]] inline interpid_t GetMainPyInterpreterID() {
+    PyInterpreterState *interp = PyInterpreterState_Main();
+    if (PyErr_Occurred() != nullptr) [[unlikely]] {
+        throw py::error_already_set();
+    }
+    if (interp == nullptr) [[unlikely]] {
+        throw std::runtime_error("Failed to get the main Python interpreter state.");
+    }
+    const interpid_t interpid = PyInterpreterState_GetID(interp);
+    if (PyErr_Occurred() != nullptr) [[unlikely]] {
+        throw py::error_already_set();
+    }
+    return interpid;
+}
+
+#else
+
+[[nodiscard]] inline bool IsCurrentPyInterpreterMain() noexcept { return true; }
+[[nodiscard]] inline interpid_t GetCurrentPyInterpreterID() noexcept { return 0; }
+[[nodiscard]] inline interpid_t GetMainPyInterpreterID() noexcept { return 0; }
+
+#endif
diff --git a/include/optree/registry.h b/include/optree/registry.h
index ac91146d..fc3cc624 100644
--- a/include/optree/registry.h
+++ b/include/optree/registry.h
@@ -23,12 +23,13 @@ limitations under the License.
 #include <string>         // std::string
 #include <unordered_map>  // std::unordered_map
 #include <unordered_set>  // std::unordered_set
-#include <utility>        // std::pair
+#include <utility>        // std::pair, std::make_pair
 
 #include <pybind11/pybind11.h>
 
 #include "optree/exceptions.h"
 #include "optree/hashing.h"
+#include "optree/pymacros.h"
 #include "optree/synchronization.h"
 
 namespace optree {
@@ -141,6 +142,52 @@ class PyTreeTypeRegistry {
         return count1;
     }
 
+    // Get the number of alive interpreters that have seen the registry.
+    [[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersAlive() {
+        const scoped_read_lock lock{sm_mutex};
+        return py::ssize_t_cast(sm_alive_interpids.size());
+    }
+
+    // Get the number of interpreters that have seen the registry.
+    [[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersSeen() {
+        const scoped_read_lock lock{sm_mutex};
+        return sm_num_interpreters_seen;
+    }
+
+    // Get the IDs of alive interpreters that have seen the registry.
+    [[nodiscard]] static inline Py_ALWAYS_INLINE std::unordered_set<interpid_t>
+    GetAliveInterpreterIDs() {
+        const scoped_read_lock lock{sm_mutex};
+        return sm_alive_interpids;
+    }
+
+    // Check whether to preserve the insertion order of the dictionary keys during flattening.
+    [[nodiscard]] static inline Py_ALWAYS_INLINE bool IsDictInsertionOrdered(
+        const std::string &registry_namespace,
+        const bool &inherit_global_namespace = true) {
+        const scoped_read_lock lock{sm_dict_order_mutex};
+
+        const auto interpid = GetCurrentPyInterpreterID();
+        const auto &namespaces = sm_dict_insertion_ordered_namespaces;
+        return (namespaces.find({interpid, registry_namespace}) != namespaces.end()) ||
+               (inherit_global_namespace && namespaces.find({interpid, ""}) != namespaces.end());
+    }
+
+    // Set the namespace to preserve the insertion order of the dictionary keys during flattening.
+    static inline Py_ALWAYS_INLINE void SetDictInsertionOrdered(
+        const bool &mode,
+        const std::string &registry_namespace) {
+        const scoped_write_lock lock{sm_dict_order_mutex};
+
+        const auto interpid = GetCurrentPyInterpreterID();
+        const auto key = std::make_pair(interpid, registry_namespace);
+        if (mode) [[likely]] {
+            sm_dict_insertion_ordered_namespaces.insert(key);
+        } else [[unlikely]] {
+            sm_dict_insertion_ordered_namespaces.erase(key);
+        }
+    }
+
     friend void BuildModule(py::module_ &mod);  // NOLINT[runtime/references]
 
  private:
@@ -173,7 +220,16 @@ class PyTreeTypeRegistry {
     NamedRegistrationsMap m_named_registrations{};
     BuiltinsTypesSet m_builtins_types{};
 
+    // A set of namespaces that preserve the insertion order of the dictionary keys during
+    // flattening.
+    static inline std::unordered_set<std::pair<interpid_t, std::string>>
+        sm_dict_insertion_ordered_namespaces{};
+    static inline read_write_mutex sm_dict_order_mutex{};
+    friend class PyTreeSpec;
+
+    static inline std::unordered_set<interpid_t> sm_alive_interpids{};
     static inline read_write_mutex sm_mutex{};
+    static inline ssize_t sm_num_interpreters_seen = 0;
 };
 
 }  // namespace optree
diff --git a/include/optree/treespec.h b/include/optree/treespec.h
index ecc6079e..c73684f6 100644
--- a/include/optree/treespec.h
+++ b/include/optree/treespec.h
@@ -17,14 +17,13 @@ limitations under the License.
 
 #pragma once
 
-#include <memory>         // std::unique_ptr
-#include <optional>       // std::optional, std::nullopt
-#include <string>         // std::string
-#include <thread>         // std::thread::id
-#include <tuple>          // std::tuple
-#include <unordered_set>  // std::unordered_set
-#include <utility>        // std::pair
-#include <vector>         // std::vector
+#include <memory>    // std::unique_ptr
+#include <optional>  // std::optional, std::nullopt
+#include <string>    // std::string
+#include <thread>    // std::thread::id
+#include <tuple>     // std::tuple
+#include <utility>   // std::pair
+#include <vector>    // std::vector
 
 #include <pybind11/pybind11.h>
 
@@ -259,31 +258,6 @@ class PyTreeSpec {
                                             const bool &none_is_leaf = false,
                                             const std::string &registry_namespace = "");
 
-    // Check if should preserve the insertion order of the dictionary keys during flattening.
-    [[nodiscard]] static inline Py_ALWAYS_INLINE bool IsDictInsertionOrdered(
-        const std::string &registry_namespace,
-        const bool &inherit_global_namespace = true) {
-        const scoped_read_lock lock{sm_is_dict_insertion_ordered_mutex};
-
-        return (sm_is_dict_insertion_ordered.find(registry_namespace) !=
-                sm_is_dict_insertion_ordered.end()) ||
-               (inherit_global_namespace &&
-                sm_is_dict_insertion_ordered.find("") != sm_is_dict_insertion_ordered.end());
-    }
-
-    // Set the namespace to preserve the insertion order of the dictionary keys during flattening.
-    static inline Py_ALWAYS_INLINE void SetDictInsertionOrdered(
-        const bool &mode,
-        const std::string &registry_namespace) {
-        const scoped_write_lock lock{sm_is_dict_insertion_ordered_mutex};
-
-        if (mode) [[likely]] {
-            sm_is_dict_insertion_ordered.insert(registry_namespace);
-        } else [[unlikely]] {
-            sm_is_dict_insertion_ordered.erase(registry_namespace);
-        }
-    }
-
     friend void BuildModule(py::module_ &mod);  // NOLINT[runtime/references]
 
  private:
@@ -423,11 +397,6 @@ class PyTreeSpec {
 
     // Used in tp_clear for GC support.
     static int PyTpClear(PyObject *self_base);
-
-    // A set of namespaces that preserve the insertion order of the dictionary keys during
-    // flattening.
-    static inline std::unordered_set<std::string> sm_is_dict_insertion_ordered{};
-    static inline read_write_mutex sm_is_dict_insertion_ordered_mutex{};
 };
 
 class PyTreeIter {
@@ -441,7 +410,8 @@ class PyTreeIter {
           m_leaf_predicate{leaf_predicate},
           m_none_is_leaf{none_is_leaf},
           m_namespace{registry_namespace},
-          m_is_dict_insertion_ordered{PyTreeSpec::IsDictInsertionOrdered(registry_namespace)} {}
+          m_is_dict_insertion_ordered{
+              PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace)} {}
 
     PyTreeIter() = delete;
     ~PyTreeIter() = default;
diff --git a/optree/_C.pyi b/optree/_C.pyi
index e7b8217b..f791c9e7 100644
--- a/optree/_C.pyi
+++ b/optree/_C.pyi
@@ -49,10 +49,13 @@
 Py_DEBUG: Final[bool]
 Py_GIL_DISABLED: Final[bool]
 PYBIND11_VERSION_HEX: Final[int]
 PYBIND11_INTERNALS_VERSION: Final[int]
+PYBIND11_INTERNALS_ID: Final[str]
+PYBIND11_MODULE_LOCAL_ID: Final[str]
 PYBIND11_HAS_NATIVE_ENUM: Final[bool]
 PYBIND11_HAS_INTERNALS_WITH_SMART_HOLDER_SUPPORT: Final[bool]
 PYBIND11_HAS_SUBINTERPRETER_SUPPORT: Final[bool]
 GLIBCXX_USE_CXX11_ABI: Final[bool]
+OPTREE_HAS_SUBINTERPRETER_SUPPORT: Final[bool]
 
 @final
 class InternalError(SystemError): ...
@@ -214,3 +217,9 @@ def set_dict_insertion_ordered(
     namespace: str = '',
 ) -> None: ...
 def get_registry_size(namespace: str | None = None) -> int: ...
+def get_num_interpreters_seen() -> int: ...
+def get_num_interpreters_alive() -> int: ...
+def get_alive_interpreter_ids() -> set[int]: ...
+def is_current_interpreter_main() -> bool: ...
+def get_current_interpreter_id() -> int: ...
+def get_main_interpreter_id() -> int: ...
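A minimal usage sketch of the introspection helpers declared in the optree/_C.pyi stub above; it only restates invariants that the new test suite further below relies on, and the subinterpreter-specific values assume a build where OPTREE_HAS_SUBINTERPRETER_SUPPORT is enabled:

import optree._C as _C

if _C.OPTREE_HAS_SUBINTERPRETER_SUPPORT:
    # "Am I the main interpreter?" is equivalent to comparing the current and main interpreter IDs.
    assert _C.is_current_interpreter_main() == (
        _C.get_current_interpreter_id() == _C.get_main_interpreter_id()
    )
    # Every interpreter that imported optree has been "seen"; only live ones remain in the ID set.
    assert _C.get_num_interpreters_alive() <= _C.get_num_interpreters_seen()
    assert _C.get_current_interpreter_id() in _C.get_alive_interpreter_ids()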
diff --git a/src/optree.cpp b/src/optree.cpp
index 54779037..9a3ae3a4 100644
--- a/src/optree.cpp
+++ b/src/optree.cpp
@@ -73,6 +73,8 @@ void BuildModule(py::module_ &mod) {  // NOLINT[runtime/references]
 #endif
     BUILDTIME_METADATA["PYBIND11_VERSION_HEX"] = py::int_(PYBIND11_VERSION_HEX);
     BUILDTIME_METADATA["PYBIND11_INTERNALS_VERSION"] = py::int_(PYBIND11_INTERNALS_VERSION);
+    BUILDTIME_METADATA["PYBIND11_INTERNALS_ID"] = py::str(PYBIND11_INTERNALS_ID);
+    BUILDTIME_METADATA["PYBIND11_MODULE_LOCAL_ID"] = py::str(PYBIND11_MODULE_LOCAL_ID);
 #if defined(PYBIND11_HAS_NATIVE_ENUM) && NONZERO_OR_EMPTY(PYBIND11_HAS_NATIVE_ENUM)
     BUILDTIME_METADATA["PYBIND11_HAS_NATIVE_ENUM"] = py::bool_(true);
 #else
@@ -95,6 +97,11 @@ void BuildModule(py::module_ &mod) {  // NOLINT[runtime/references]
 #else
     BUILDTIME_METADATA["GLIBCXX_USE_CXX11_ABI"] = py::bool_(false);
 #endif
+#if defined(OPTREE_HAS_SUBINTERPRETER_SUPPORT)
+    BUILDTIME_METADATA["OPTREE_HAS_SUBINTERPRETER_SUPPORT"] = py::bool_(true);
+#else
+    BUILDTIME_METADATA["OPTREE_HAS_SUBINTERPRETER_SUPPORT"] = py::bool_(false);
+#endif
     mod.attr("BUILDTIME_METADATA") = std::move(BUILDTIME_METADATA);
 
     py::exec(
@@ -139,12 +146,12 @@ void BuildModule(py::module_ &mod) {  // NOLINT[runtime/references]
              py::pos_only(),
              py::arg("namespace") = "")
         .def("is_dict_insertion_ordered",
-             &PyTreeSpec::IsDictInsertionOrdered,
+             &PyTreeTypeRegistry::IsDictInsertionOrdered,
              "Return whether need to preserve the dict insertion order during flattening.",
              py::arg("namespace") = "",
              py::arg("inherit_global_namespace") = true)
         .def("set_dict_insertion_ordered",
-             &PyTreeSpec::SetDictInsertionOrdered,
+             &PyTreeTypeRegistry::SetDictInsertionOrdered,
              "Set whether need to preserve the dict insertion order during flattening.",
              py::arg("mode"),
              py::pos_only(),
@@ -153,6 +160,24 @@
              &PyTreeTypeRegistry::GetRegistrySize,
              "Get the number of registered types.",
              py::arg("namespace") = std::nullopt)
+        .def("get_num_interpreters_seen",
+             &PyTreeTypeRegistry::GetNumInterpretersSeen,
+             "Get the number of interpreters that have seen the registry.")
+        .def("get_num_interpreters_alive",
+             &PyTreeTypeRegistry::GetNumInterpretersAlive,
+             "Get the number of alive interpreters that have seen the registry.")
+        .def("get_alive_interpreter_ids",
+             &PyTreeTypeRegistry::GetAliveInterpreterIDs,
+             "Get the IDs of alive interpreters that have seen the registry.")
+        .def("is_current_interpreter_main",
+             &IsCurrentPyInterpreterMain,
+             "Check whether the current interpreter is the main interpreter.")
+        .def("get_current_interpreter_id",
+             &GetCurrentPyInterpreterID,
+             "Get the ID of the current interpreter.")
+        .def("get_main_interpreter_id",
+             &GetMainPyInterpreterID,
+             "Get the ID of the main interpreter.")
         .def("flatten",
              &PyTreeSpec::Flatten,
             "Flatten a pytree.",
@@ -528,7 +553,11 @@ void BuildModule(py::module_ &mod) {  // NOLINT[runtime/references]
 
 // NOLINTBEGIN[cppcoreguidelines-pro-bounds-pointer-arithmetic,cppcoreguidelines-pro-type-vararg]
 #if PYBIND11_VERSION_HEX >= 0x020D00F0  // pybind11 2.13.0
+# if defined(OPTREE_HAS_SUBINTERPRETER_SUPPORT)
+PYBIND11_MODULE(_C, mod, py::mod_gil_not_used(), py::multiple_interpreters::per_interpreter_gil())
+# else
 PYBIND11_MODULE(_C, mod, py::mod_gil_not_used())
+# endif
 #else
 PYBIND11_MODULE(_C, mod)
 #endif
diff --git a/src/registry.cpp b/src/registry.cpp
index faa6c662..4e1a97ff 100644
--- a/src/registry.cpp
+++ b/src/registry.cpp
@@ -310,6 +310,14 @@ template PyTreeKind PyTreeTypeRegistry::GetKind(
 
 /*static*/ void PyTreeTypeRegistry::Init() {
     const scoped_write_lock lock{sm_mutex};
 
+    const auto interpid = GetCurrentPyInterpreterID();
+
+    ++sm_num_interpreters_seen;
+    EXPECT_TRUE(
+        sm_alive_interpids.insert(interpid).second,
+        "The current interpreter ID should not be already present in the alive interpreters "
+        "set.");
+
     auto &registry1 = GetSingleton();
     auto &registry2 = GetSingleton();
@@ -325,6 +333,32 @@
 
 /*static*/ void PyTreeTypeRegistry::Clear() {
     const scoped_write_lock lock{sm_mutex};
 
+    const auto interpid = GetCurrentPyInterpreterID();
+
+    EXPECT_NE(sm_alive_interpids.find(interpid),
+              sm_alive_interpids.end(),
+              "The current interpreter ID should be present in the alive interpreters set.");
+    sm_alive_interpids.erase(interpid);
+
+    {
+        const scoped_write_lock namespace_lock{sm_dict_order_mutex};
+        auto entries = reserved_vector<std::pair<interpid_t, std::string>>(4);
+        for (const auto &entry : sm_dict_insertion_ordered_namespaces) {
+            if (entry.first == interpid) [[likely]] {
+                entries.emplace_back(entry);
+            }
+        }
+        for (const auto &entry : entries) {
+            sm_dict_insertion_ordered_namespaces.erase(entry);
+        }
+        if (sm_alive_interpids.empty()) [[likely]] {
+            EXPECT_TRUE(
+                sm_dict_insertion_ordered_namespaces.empty(),
+                "The dict insertion ordered namespaces set should be empty when there is no "
+                "alive Python interpreter.");
+        }
+    }
+
     auto &registry1 = GetSingleton();
     auto &registry2 = GetSingleton();
diff --git a/src/treespec/constructors.cpp b/src/treespec/constructors.cpp
index cc4ec090..49c861ca 100644
--- a/src/treespec/constructors.cpp
+++ b/src/treespec/constructors.cpp
@@ -170,7 +170,8 @@ template
                 keys = DictKeys(dict);
                 if (node.kind != PyTreeKind::OrderedDict) [[likely]] {
                     node.original_keys = py::getattr(keys, "copy")();
-                    if (!IsDictInsertionOrdered(registry_namespace)) [[likely]] {
+                    if (!PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace))
+                        [[likely]] {
                         TotalOrderSort(keys);
                     }
                 }
diff --git a/src/treespec/flatten.cpp b/src/treespec/flatten.cpp
index 8c630224..44aa6677 100644
--- a/src/treespec/flatten.cpp
+++ b/src/treespec/flatten.cpp
@@ -208,11 +208,12 @@ bool PyTreeSpec::FlattenInto(const py::handle &handle,
     bool is_dict_insertion_ordered_in_current_namespace = false;
     {
 #if defined(HAVE_READ_WRITE_LOCK)
-        const scoped_read_lock lock{sm_is_dict_insertion_ordered_mutex};
+        const scoped_read_lock lock{PyTreeTypeRegistry::sm_dict_order_mutex};
 #endif
-        is_dict_insertion_ordered = IsDictInsertionOrdered(registry_namespace);
+        is_dict_insertion_ordered = PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace);
         is_dict_insertion_ordered_in_current_namespace =
-            IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false);
+            PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace,
+                                                       /*inherit_global_namespace=*/false);
     }
 
     if (none_is_leaf) [[unlikely]] {
@@ -484,11 +485,12 @@ bool PyTreeSpec::FlattenIntoWithPath(const py::handle &handle,
     bool is_dict_insertion_ordered_in_current_namespace = false;
     {
 #if defined(HAVE_READ_WRITE_LOCK)
-        const scoped_read_lock lock{sm_is_dict_insertion_ordered_mutex};
+        const scoped_read_lock lock{PyTreeTypeRegistry::sm_dict_order_mutex};
 #endif
-        is_dict_insertion_ordered = IsDictInsertionOrdered(registry_namespace);
+        is_dict_insertion_ordered = PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace);
         is_dict_insertion_ordered_in_current_namespace =
-            IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false);
+
PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace, + /*inherit_global_namespace=*/false); } auto stack = reserved_vector(4); diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py new file mode 100644 index 00000000..0af8cfdc --- /dev/null +++ b/tests/concurrent/test_subinterpreters.py @@ -0,0 +1,345 @@ +# Copyright 2022-2025 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import atexit +import contextlib +import random +import sys + +import pytest + +from helpers import ( + ANDROID, + IOS, + OPTREE_HAS_SUBINTERPRETER_SUPPORT, + PYPY, + WASM, + Py_DEBUG, + check_script_in_subprocess, +) + + +if ( + PYPY + or WASM + or IOS + or ANDROID + or sys.version_info < (3, 14) + or not getattr(sys.implementation, 'supports_isolated_interpreters', False) + or not OPTREE_HAS_SUBINTERPRETER_SUPPORT +): + pytest.skip('Test for CPython 3.14+ only', allow_module_level=True) + + +from concurrent import interpreters +from concurrent.futures import InterpreterPoolExecutor, as_completed + + +if not Py_DEBUG: + NUM_WORKERS = 8 + NUM_FUTURES = 32 + NUM_FLAKY_RERUNS = 16 +else: + NUM_WORKERS = 4 + NUM_FUTURES = 16 + NUM_FLAKY_RERUNS = 8 + + +EXECUTOR = InterpreterPoolExecutor(max_workers=NUM_WORKERS) +atexit.register(EXECUTOR.shutdown) + + +def run(func, /, *args, **kwargs): + future = EXECUTOR.submit(func, *args, **kwargs) + exception = future.exception() + if exception is not None: + raise exception + return future.result() + + +def concurrent_run(func, /, *args, **kwargs): + futures = [EXECUTOR.submit(func, *args, **kwargs) for _ in range(NUM_FUTURES)] + future2index = {future: i for i, future in enumerate(futures)} + completed_futures = sorted(as_completed(futures), key=future2index.get) + first_exception = next(filter(None, (future.exception() for future in completed_futures)), None) + if first_exception is not None: + raise first_exception + return [future.result() for future in completed_futures] + + +def check_module_importable(): + import collections + import time + + import optree + import optree._C + + is_current_interpreter_main = optree._C.is_current_interpreter_main() + main_interpreter_id = optree._C.get_main_interpreter_id() + current_interpreter_id = optree._C.get_current_interpreter_id() + + if is_current_interpreter_main != (main_interpreter_id == current_interpreter_id): + raise RuntimeError('interpreter identity mismatch') + + if not is_current_interpreter_main and optree._C.get_registry_size() != 8: + raise RuntimeError('registry size mismatch') + + tree = { + 'b': [2, (3, 4)], + 'a': 1, + 'c': collections.OrderedDict( + f=None, + d=5, + e=time.struct_time([6] + [None] * (time.struct_time.n_sequence_fields - 1)), + ), + 'g': collections.defaultdict(list, h=collections.deque([7, 8, 9], maxlen=10)), + } + + leaves1, treespec1 = optree.tree_flatten(tree, none_is_leaf=False) + reconstructed1 = optree.tree_unflatten(treespec1, 
leaves1) + if reconstructed1 != tree: + raise RuntimeError('unflatten/flatten mismatch') + if treespec1.num_leaves != len(leaves1): + raise RuntimeError(f'num_leaves mismatch: ({leaves1}, {treespec1})') + if leaves1 != [1, 2, 3, 4, 5, 6, 7, 8, 9]: + raise RuntimeError(f'flattened leaves mismatch: ({leaves1}, {treespec1})') + + leaves2, treespec2 = optree.tree_flatten(tree, none_is_leaf=True) + reconstructed2 = optree.tree_unflatten(treespec2, leaves2) + if reconstructed2 != tree: + raise RuntimeError('unflatten/flatten mismatch') + if treespec2.num_leaves != len(leaves2): + raise RuntimeError(f'num_leaves mismatch: ({leaves2}, {treespec2})') + if leaves2 != [ + 1, + 2, + 3, + 4, + None, + 5, + 6, + *([None] * (time.struct_time.n_sequence_fields - 1)), + 7, + 8, + 9, + ]: + raise RuntimeError(f'flattened leaves mismatch: ({leaves2}, {treespec2})') + + _ = optree.tree_flatten_with_path(tree, none_is_leaf=False) + _ = optree.tree_flatten_with_path(tree, none_is_leaf=True) + _ = optree.tree_flatten_with_accessor(tree, none_is_leaf=False) + _ = optree.tree_flatten_with_accessor(tree, none_is_leaf=True) + + return ( + is_current_interpreter_main, + main_interpreter_id, + id(type(None)), + id(tuple), + id(list), + id(dict), + id(collections.OrderedDict), + ) + + +def test_import(): + import collections + + expected = ( + False, + 0, + id(type(None)), + id(tuple), + id(list), + id(dict), + id(collections.OrderedDict), + ) + + assert check_module_importable() == (True, *expected[1:]) + assert run(check_module_importable) == expected + + for _ in range(random.randint(5, 10)): + with contextlib.closing(interpreters.create()) as subinterpreter: + subinterpreter.exec('import optree') + with contextlib.closing(interpreters.create()) as subinterpreter: + assert subinterpreter.call(check_module_importable) == expected + + for actual in concurrent_run(check_module_importable): + assert actual == expected + + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range(random.randint(5, 10)) + ] + random.shuffle(subinterpreters) + for subinterpreter in subinterpreters: + subinterpreter.exec('import optree') + + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range(random.randint(5, 10)) + ] + random.shuffle(subinterpreters) + for subinterpreter in subinterpreters: + assert subinterpreter.call(check_module_importable) == expected + + +def test_import_in_subinterpreter_after_main(): + check_script_in_subprocess( + """ + import contextlib + import gc + from concurrent import interpreters + + import optree + + subinterpreter = None + with contextlib.closing(interpreters.create()) as subinterpreter: + subinterpreter.exec('import optree') + + del optree, subinterpreter + for _ in range(10): + gc.collect() + """, + output='', + rerun=NUM_FLAKY_RERUNS, + ) + + check_script_in_subprocess( + f""" + import contextlib + import gc + import random + from concurrent import interpreters + + import optree + + subinterpreter = subinterpreters = stack = None + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range({NUM_FUTURES}) + ] + random.shuffle(subinterpreters) + for subinterpreter in subinterpreters: + subinterpreter.exec('import optree') + + del optree, subinterpreter, subinterpreters, stack + for _ in range(10): + gc.collect() + """, + output='', + 
rerun=NUM_FLAKY_RERUNS, + ) + + +def test_import_in_subinterpreter_before_main(): + check_script_in_subprocess( + """ + import contextlib + import gc + from concurrent import interpreters + + subinterpreter = None + with contextlib.closing(interpreters.create()) as subinterpreter: + subinterpreter.exec('import optree') + + import optree + + del optree, subinterpreter + for _ in range(10): + gc.collect() + """, + output='', + rerun=NUM_FLAKY_RERUNS, + ) + + check_script_in_subprocess( + f""" + import contextlib + import gc + import random + from concurrent import interpreters + + subinterpreter = subinterpreters = stack = None + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range({NUM_FUTURES}) + ] + random.shuffle(subinterpreters) + for subinterpreter in subinterpreters: + subinterpreter.exec('import optree') + + import optree + + del optree, subinterpreter, subinterpreters, stack + for _ in range(10): + gc.collect() + """, + output='', + rerun=NUM_FLAKY_RERUNS, + ) + + check_script_in_subprocess( + f""" + import contextlib + import gc + import random + from concurrent import interpreters + + subinterpreter = subinterpreters = stack = None + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range({NUM_FUTURES}) + ] + random.shuffle(subinterpreters) + for subinterpreter in subinterpreters: + subinterpreter.exec('import optree') + + import optree + + del optree, subinterpreter, subinterpreters, stack + for _ in range(10): + gc.collect() + """, + output='', + rerun=NUM_FLAKY_RERUNS, + ) + + +def test_import_in_subinterpreters_concurrently(): + check_script_in_subprocess( + f""" + from concurrent.futures import InterpreterPoolExecutor, as_completed + + def check_import(): + import optree + + if optree._C.get_registry_size() != 8: + raise RuntimeError('registry size mismatch') + if optree._C.is_current_interpreter_main(): + raise RuntimeError('expected subinterpreter') + + with InterpreterPoolExecutor(max_workers={NUM_WORKERS}) as executor: + futures = [executor.submit(check_import) for _ in range({NUM_FUTURES})] + for future in as_completed(futures): + future.result() + """, + output='', + rerun=NUM_FLAKY_RERUNS, + ) diff --git a/tests/helpers.py b/tests/helpers.py index c068ab2e..bf37d176 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -36,6 +36,7 @@ import optree from optree._C import ( + OPTREE_HAS_SUBINTERPRETER_SUPPORT, PYBIND11_HAS_NATIVE_ENUM, PYBIND11_HAS_SUBINTERPRETER_SUPPORT, Py_DEBUG, @@ -55,6 +56,7 @@ _ = PYBIND11_HAS_NATIVE_ENUM _ = PYBIND11_HAS_SUBINTERPRETER_SUPPORT +_ = OPTREE_HAS_SUBINTERPRETER_SUPPORT if sysconfig.get_config_var('Py_DEBUG') is None: assert Py_DEBUG == hasattr(sys, 'gettotalrefcount')
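Taken together, the patch registers the _C extension with py::multiple_interpreters::per_interpreter_gil() and keeps per-interpreter bookkeeping (the interpreter-ID sets in PyTreeTypeRegistry), so optree can be imported and used independently in each CPython 3.14+ subinterpreter. A condensed sketch of the scenario the new test file exercises, assuming a build with OPTREE_HAS_SUBINTERPRETER_SUPPORT enabled; it is not part of the test suite itself:

import contextlib
from concurrent import interpreters  # CPython 3.14+


def roundtrip_in_this_interpreter():
    # Runs in whichever interpreter calls it; importing optree there initializes
    # the pytree type registry for that interpreter (PyTreeTypeRegistry::Init).
    import optree

    tree = {'b': (2, 3), 'a': 1}
    leaves, treespec = optree.tree_flatten(tree)
    return tuple(leaves), optree.tree_unflatten(treespec, leaves) == tree


with contextlib.closing(interpreters.create()) as subinterpreter:
    leaves, roundtrip_ok = subinterpreter.call(roundtrip_in_this_interpreter)
    assert leaves == (1, 2, 3)  # dict keys are sorted during flattening: 'a' before 'b'
    assert roundtrip_ok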