Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion tests/test_encoders.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,30 @@
import unittest
from pathlib import Path

# torch is optional for this test module: fall back to a ``None`` sentinel so
# the ``skipIf`` guard on the test class can detect its absence. An
# unconditional ``import torch`` placed ahead of this guard would abort test
# collection with ImportError before the skip logic ever runs.
try:
    import torch
except ImportError:
    torch = None  # type: ignore[assignment]

try:
import core.vision_encoder.pe as pe
import core.vision_encoder.transforms as transforms
except ImportError:
pe = None # type: ignore[assignment]
transforms = None # type: ignore[assignment]

try:
from PIL import Image # noqa: F401
except ImportError:
Image = None # type: ignore[assignment]

from vectorvfs.encoders import PerceptionEncoder


@unittest.skipIf(
torch is None or pe is None or transforms is None or Image is None,
"Perception encoder dependencies are not installed",
)
class TestPerceptionEncoder(unittest.TestCase):
def setUp(self) -> None:
self.data_path = Path(__file__).parent / "data"
Expand Down
36 changes: 36 additions & 0 deletions tests/test_optional_dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest


def test_require_dependencies_provides_descriptive_error(monkeypatch):
    """Check that the missing-dependency error names every absent component."""
    import vectorvfs.encoders as encoders

    # Pretend the whole optional stack failed to import: blank out each module
    # handle, then record a matching captured import error for each group.
    for module_attr in ("torch", "pe", "transforms", "Image"):
        monkeypatch.setattr(encoders, module_attr, None)
    captured_errors = {
        "_torch_import_error": "torch",
        "_encoder_import_error": "pe",
        "_pillow_import_error": "pillow",
    }
    for error_attr, label in captured_errors.items():
        monkeypatch.setattr(encoders, error_attr, ImportError(label))

    expected = "PerceptionEncoder requires optional dependencies"
    with pytest.raises(ImportError, match=expected) as excinfo:
        encoders._require_dependencies()

    # The guidance must list each missing component so users know exactly
    # which extras to install when the perception stack is unavailable.
    message = str(excinfo.value)
    for component in ("torch", "core.vision_encoder", "Pillow"):
        assert component in message


def test_require_torch_provides_descriptive_error(monkeypatch):
    """Check that ``_require_torch`` raises an actionable error without torch."""
    import vectorvfs.vfsstore as vfsstore

    # Simulate an environment in which importing torch failed.
    simulated_failure = ImportError("torch")
    monkeypatch.setattr(vfsstore, "torch", None)
    monkeypatch.setattr(vfsstore, "_torch_import_error", simulated_failure)

    expected = "VFSStore requires the optional 'torch' dependency"
    with pytest.raises(ImportError, match=expected):
        vfsstore._require_torch()
84 changes: 74 additions & 10 deletions vectorvfs/encoders.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,69 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING

try: # Optional heavy dependencies
import core.vision_encoder.pe as pe
import core.vision_encoder.transforms as transforms
except ImportError as e: # pragma: no cover - exercised through runtime checks
pe = None # type: ignore[assignment]
transforms = None # type: ignore[assignment]
_encoder_import_error = e
else:
_encoder_import_error = None

try: # Optional heavy dependency
import torch
except ImportError as e: # pragma: no cover - exercised through runtime checks
torch = None # type: ignore[assignment]
_torch_import_error = e
else:
_torch_import_error = None

try: # Optional dependency
from PIL import Image
except ImportError as e: # pragma: no cover - exercised through runtime checks
Image = None # type: ignore[assignment]
_pillow_import_error = e
else:
_pillow_import_error = None

if TYPE_CHECKING: # pragma: no cover - type checking only
from torch import Tensor


def _require_dependencies() -> None:
    """Ensure optional runtime dependencies are available.

    Inspects the module-level ``torch``, ``pe``, ``transforms`` and ``Image``
    names, which the guarded imports at the top of this module set to ``None``
    when the corresponding package failed to import.

    Raises:
        ImportError: If any optional dependency is missing. The message lists
            every missing component, and the first captured import failure
            that belongs to a missing component is chained as ``__cause__``.
    """
    # NOTE: do not import the optional packages inside this function. A local
    # import would shadow the module-level sentinels (turning the ``is None``
    # checks into dead code) and would surface a bare ImportError instead of
    # the descriptive message built below.
    encoder_missing = pe is None or transforms is None

    missing = []
    if torch is None:
        missing.append("torch")
    if encoder_missing:
        missing.append("core.vision_encoder")
    if Image is None:
        missing.append("Pillow")

    if not missing:
        return

    help_text = (
        "PerceptionEncoder requires optional dependencies: "
        + ", ".join(missing)
        + ". Install them to enable vision/text encoding."
    )

    # Prefer to chain the first captured import error for context; only errors
    # whose dependency is actually missing are considered.
    candidates = (
        (torch is None, _torch_import_error),
        (encoder_missing, _encoder_import_error),
        (Image is None, _pillow_import_error),
    )
    for is_missing, captured in candidates:
        if is_missing and captured is not None:
            raise ImportError(help_text) from captured
    raise ImportError(help_text)


class DualEncoder(ABC):
Expand All @@ -15,7 +74,7 @@ class DualEncoder(ABC):
as well as obtaining the scaling factor for logits in similarity computation.
"""
@abstractmethod
def encode_vision(self, file: Path) -> torch.Tensor:
def encode_vision(self, file: Path) -> "Tensor":
"""
Encode an image file into a tensor representation.

Expand All @@ -25,7 +84,7 @@ def encode_vision(self, file: Path) -> torch.Tensor:
...

@abstractmethod
def encode_text(self, text: str) -> torch.Tensor:
def encode_text(self, text: str) -> "Tensor":
"""
Encode a text string into a tensor representation.

Expand All @@ -35,7 +94,7 @@ def encode_text(self, text: str) -> torch.Tensor:
...

@abstractmethod
def logit_scale(self) -> torch.Tensor:
def logit_scale(self) -> "Tensor":
"""
Get the scale factor applied to logits for similarity computation.

Expand All @@ -51,14 +110,19 @@ class PerceptionEncoder(DualEncoder):
:param model_name: Name of the CLIP model configuration to load (default: "PE-Core-L14-336").
"""
def __init__(self, model_name: str = "PE-Core-L14-336") -> None:
    """Load the pretrained perception model plus its image/text preprocessing.

    The model is placed on CUDA when ``torch.cuda.is_available()`` reports a
    device and on the CPU otherwise.

    :param model_name: Configuration name passed to ``pe.CLIP.from_config``.
    :raises ImportError: If the optional encoder dependencies are missing.
    """
    _require_dependencies()

    # Narrow the optional module handles for static type checkers.
    assert torch is not None  # for type checkers
    assert pe is not None and transforms is not None

    if torch.cuda.is_available():
        device_name = "cuda"
    else:
        device_name = "cpu"
    self.device = torch.device(device_name)
    self.model_name = model_name

    self.model = pe.CLIP.from_config(model_name, pretrained=True)
    self.model = self.model.to(self.device)

    # Preprocessing helpers are sized from the loaded model's own settings.
    image_size = self.model.image_size
    self.preprocess = transforms.get_image_transform(image_size)
    context_length = self.model.context_length
    self.tokenizer = transforms.get_text_tokenizer(context_length)

def encode_vision(self, file: Path) -> torch.Tensor:
def encode_vision(self, file: Path) -> "Tensor":
"""
Encode an image file into a tensor of image features using the perception model.

Expand All @@ -72,7 +136,7 @@ def encode_vision(self, file: Path) -> torch.Tensor:
image_features, _, _ = self.model(image, None)
return image_features

def encode_text(self, text: str) -> torch.Tensor:
def encode_text(self, text: str) -> "Tensor":
"""
Encode a text string into a tensor of text features using the perception model.

Expand All @@ -84,7 +148,7 @@ def encode_text(self, text: str) -> torch.Tensor:
_, text_features, _ = self.model(None, tokenized_text)
return text_features

def logit_scale(self) -> torch.Tensor:
def logit_scale(self) -> "Tensor":
"""
Get the exponential of the model's logit scale parameter for similarity computation.

Expand Down
41 changes: 35 additions & 6 deletions vectorvfs/vfsstore.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,33 @@
from __future__ import annotations

import io
import os
from pathlib import Path
from typing import TYPE_CHECKING

# torch is an optional dependency. When it cannot be imported, keep a ``None``
# sentinel and remember the original failure so ``_require_torch`` can raise a
# descriptive error later. There must be no unconditional ``import torch``
# after this guard: it would re-raise the very ImportError handled here and
# make the module unimportable without torch.
try:  # Optional dependency
    import torch
except ImportError as e:  # pragma: no cover - exercised through runtime checks
    torch = None  # type: ignore[assignment]
    _torch_import_error = e
else:
    _torch_import_error = None

if TYPE_CHECKING:  # pragma: no cover - type checking only
    from torch import Tensor

def _require_torch() -> None:
    """Fail with an actionable ImportError if torch could not be imported.

    Returns silently when the module-level ``torch`` handle is set.

    Raises:
        ImportError: If ``torch`` is ``None``; the import failure captured at
            module load is chained as the cause when one was recorded.
    """
    if torch is not None:
        return

    message = (
        "VFSStore requires the optional 'torch' dependency to serialize "
        "embeddings. Install torch to enable tensor persistence."
    )
    error = ImportError(message)
    if _torch_import_error is None:
        raise error
    # Chain the failure captured at import time so the root cause stays visible.
    raise error from _torch_import_error


class XAttrFile:
Expand Down Expand Up @@ -48,21 +73,25 @@ class VFSStore:
def __init__(self, xattrfile: XAttrFile) -> None:
    """Bind the store to *xattrfile*, the object tensors are written to and read from."""
    self.xattrfile = xattrfile

def _tensor_to_bytes(self, tensor: "Tensor") -> bytes:
    """Serialize *tensor* with ``torch.save`` and return the raw bytes.

    The annotation is a string so that defining this method never touches
    ``torch``, which may be ``None`` when the optional dependency is absent.

    Raises:
        ImportError: If torch is not installed (raised by ``_require_torch``).
    """
    _require_torch()
    assert torch is not None  # for type checkers
    buffer = io.BytesIO()
    torch.save(tensor, buffer)
    return buffer.getvalue()

def _bytes_to_tensor(self, b: bytes, map_location=None) -> "Tensor":
    """Deserialize bytes produced by ``torch.save`` back into a tensor.

    The annotation is a string so that defining this method never touches
    ``torch``, which may be ``None`` when the optional dependency is absent.

    :param b: Serialized tensor payload.
    :param map_location: Forwarded unchanged to ``torch.load``.
    :raises ImportError: If torch is not installed (raised by ``_require_torch``).
    """
    _require_torch()
    assert torch is not None  # for type checkers
    buffer = io.BytesIO(b)
    # weights_only=True is passed so torch.load runs in its restricted mode.
    return torch.load(buffer, map_location=map_location, weights_only=True)

def write_tensor(self, tensor: "Tensor") -> int:
    """Serialize *tensor* and store it under the ``user.vectorvfs`` attribute.

    The annotation is a string so that defining this method never touches
    ``torch``, which may be ``None`` when the optional dependency is absent.

    :return: Number of bytes written.
    """
    btensor = self._tensor_to_bytes(tensor)
    self.xattrfile.write("user.vectorvfs", btensor)
    return len(btensor)

def read_tensor(self) -> "Tensor":
    """Read the ``user.vectorvfs`` attribute and deserialize it into a tensor.

    The annotation is a string so that defining this method never touches
    ``torch``, which may be ``None`` when the optional dependency is absent.
    """
    btensor = self.xattrfile.read("user.vectorvfs")
    return self._bytes_to_tensor(btensor)