diff --git a/README.md b/README.md
index a4c21f2..61c7e41 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@
-
+
diff --git a/vicinity/backends/__init__.py b/vicinity/backends/__init__.py
index 28782a9..e615edf 100644
--- a/vicinity/backends/__init__.py
+++ b/vicinity/backends/__init__.py
@@ -1,3 +1,4 @@
+from importlib.util import find_spec
from typing import Union
from vicinity.backends.base import AbstractBackend
@@ -5,35 +6,59 @@
from vicinity.datatypes import Backend
+class OptionalDependencyError(ImportError):
+ def __init__(self, backend: Backend, extra: str) -> None:
+ msg = f"{backend} requires extra '{extra}'.\n" f"Install it with: pip install 'vicinity[{extra}]'\n"
+ super().__init__(msg)
+ self.backend = backend
+ self.extra = extra
+
+
+def _require(module_name: str, backend: Backend, extra: str) -> None:
+ """Check if a dependency is importable, otherwise raise an error."""
+ if find_spec(module_name) is None:
+ raise OptionalDependencyError(backend, extra)
+
+
def get_backend_class(backend: Union[Backend, str]) -> type[AbstractBackend]:
- """Get all available backends."""
+ """Get the requested backend and ensure its dependencies are installed."""
backend = Backend(backend)
+
if backend == Backend.BASIC:
return BasicBackend
+
elif backend == Backend.HNSW:
+ _require("hnswlib", backend, "hnsw")
from vicinity.backends.hnsw import HNSWBackend
return HNSWBackend
+
elif backend == Backend.ANNOY:
+ _require("annoy", backend, "annoy")
from vicinity.backends.annoy import AnnoyBackend
return AnnoyBackend
+
elif backend == Backend.PYNNDESCENT:
+ _require("pynndescent", backend, "pynndescent")
from vicinity.backends.pynndescent import PyNNDescentBackend
return PyNNDescentBackend
elif backend == Backend.FAISS:
+ _require("faiss", backend, "faiss")
from vicinity.backends.faiss import FaissBackend
return FaissBackend
elif backend == Backend.USEARCH:
+ _require("usearch", backend, "usearch")
from vicinity.backends.usearch import UsearchBackend
return UsearchBackend
elif backend == Backend.VOYAGER:
+ _require("voyager", backend, "voyager")
from vicinity.backends.voyager import VoyagerBackend
return VoyagerBackend
diff --git a/vicinity/version.py b/vicinity/version.py
index 29d357c..56a9515 100644
--- a/vicinity/version.py
+++ b/vicinity/version.py
@@ -1,2 +1,2 @@
-__version_triple__ = (0, 4, 2)
+__version_triple__ = (0, 4, 3)
__version__ = ".".join(map(str, __version_triple__))