diff --git a/README.md b/README.md index a4c21f2..61c7e41 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@

Package version - Supported Python versions + Docs Downloads diff --git a/vicinity/backends/__init__.py b/vicinity/backends/__init__.py index 28782a9..e615edf 100644 --- a/vicinity/backends/__init__.py +++ b/vicinity/backends/__init__.py @@ -1,3 +1,4 @@ +from importlib.util import find_spec from typing import Union from vicinity.backends.base import AbstractBackend @@ -5,35 +6,59 @@ from vicinity.datatypes import Backend +class OptionalDependencyError(ImportError): + def __init__(self, backend: Backend, extra: str) -> None: + msg = f"{backend} requires extra '{extra}'.\n" f"Install it with: pip install 'vicinity[{extra}]'\n" + super().__init__(msg) + self.backend = backend + self.extra = extra + + +def _require(module_name: str, backend: Backend, extra: str) -> None: + """Check if a dependency is importable, otherwise raise an error.""" + if find_spec(module_name) is None: + raise OptionalDependencyError(backend, extra) + + def get_backend_class(backend: Union[Backend, str]) -> type[AbstractBackend]: - """Get all available backends.""" + """Get the requested backend and ensure its dependencies are installed.""" backend = Backend(backend) + if backend == Backend.BASIC: return BasicBackend + elif backend == Backend.HNSW: + _require("hnswlib", backend, "hnsw") from vicinity.backends.hnsw import HNSWBackend return HNSWBackend + elif backend == Backend.ANNOY: + _require("annoy", backend, "annoy") from vicinity.backends.annoy import AnnoyBackend return AnnoyBackend + elif backend == Backend.PYNNDESCENT: + _require("pynndescent", backend, "pynndescent") from vicinity.backends.pynndescent import PyNNDescentBackend return PyNNDescentBackend elif backend == Backend.FAISS: + _require("faiss", backend, "faiss") from vicinity.backends.faiss import FaissBackend return FaissBackend elif backend == Backend.USEARCH: + _require("usearch", backend, "usearch") from vicinity.backends.usearch import UsearchBackend return UsearchBackend elif backend == Backend.VOYAGER: + _require("voyager", backend, "voyager") from vicinity.backends.voyager import VoyagerBackend return VoyagerBackend diff --git a/vicinity/version.py b/vicinity/version.py index 29d357c..56a9515 100644 --- a/vicinity/version.py +++ b/vicinity/version.py @@ -1,2 +1,2 @@ -__version_triple__ = (0, 4, 2) +__version_triple__ = (0, 4, 3) __version__ = ".".join(map(str, __version_triple__))