diff --git a/openml/extensions/functions.py b/openml/extensions/functions.py index 7a944c997..06902325e 100644 --- a/openml/extensions/functions.py +++ b/openml/extensions/functions.py @@ -1,6 +1,7 @@ # License: BSD 3-Clause from __future__ import annotations +import importlib.util from typing import TYPE_CHECKING, Any # Need to implement the following by its full path because otherwise it won't be possible to @@ -16,8 +17,9 @@ SKLEARN_HINT = ( "But it looks related to scikit-learn. " "Please install the OpenML scikit-learn extension (openml-sklearn) and try again. " + "You can use `pip install openml-sklearn` for installation." "For more information, see " - "https://github.com/openml/openml-sklearn?tab=readme-ov-file#installation" + "https://docs.openml.org/python/extensions/" ) @@ -58,6 +60,10 @@ def get_extension_by_flow( ------- Extension or None """ + # import openml_sklearn to register SklearnExtension + if importlib.util.find_spec("openml_sklearn"): + import openml_sklearn # noqa: F401 + candidates = [] for extension_class in openml.extensions.extensions: if extension_class.can_handle_flow(flow): @@ -103,6 +109,10 @@ def get_extension_by_model( ------- Extension or None """ + # import openml_sklearn to register SklearnExtension + if importlib.util.find_spec("openml_sklearn"): + import openml_sklearn # noqa: F401 + candidates = [] for extension_class in openml.extensions.extensions: if extension_class.can_handle_model(model):