Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion converter/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ def from_agent_farm(db_path: str, config: Optional[Dict] = None) -> AgentMemoryS

# For memory verification, we'll verify that the add_memory calls were successful
# by checking that we processed all the imported memories
logger.info(f"Verification: {len(all_memories)} memories were imported and processed")
logger.info(
f"Verification: {len(all_memories)} memories were imported and processed"
)

return memory_system

Expand Down
280 changes: 131 additions & 149 deletions main_demo.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion memory/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
import logging
from typing import Any, Dict, List, Optional, Union

from memory.space import MemorySpace
from memory.config import MemoryConfig
from memory.space import MemorySpace

logger = logging.getLogger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions memory/embeddings/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,9 @@ def add(

# Store in Redis
key = f"{self.index_name}:{id}"
self.redis.hset(
self.redis.hset_dict(
key,
mapping={
{
self.vector_field: vector_bytes,
"metadata": metadata_json,
"timestamp": int(time.time()),
Expand Down
7 changes: 5 additions & 2 deletions memory/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(self, agent_id: str, config: MemoryConfig):

# Initialize vector store
self.vector_store = VectorStore(
redis_client=self.stm_store.redis,
redis_client=None,
stm_dimension=config.autoencoder_config.stm_dim,
im_dimension=config.autoencoder_config.im_dim,
ltm_dimension=config.autoencoder_config.ltm_dim,
Expand Down Expand Up @@ -293,7 +293,7 @@ def _create_memory_entry(
data = self.compression_engine.compress(data, level=2)

# Create standardized memory entry
return {
memory_entry = {
"memory_id": memory_id,
"agent_id": self.agent_id,
"step_number": step_number,
Expand All @@ -311,6 +311,9 @@ def _create_memory_entry(
"embeddings": embeddings,
}

self.vector_store.store_memory_vectors(memory_entry, tier)
return memory_entry

def _check_memory_transition(self) -> None:
"""Check if memories need to be transitioned between tiers.

Expand Down
31 changes: 24 additions & 7 deletions tests/converter/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytest
from sqlalchemy.exc import SQLAlchemyError
import numpy as np

from converter.config import ConverterConfig
from converter.converter import from_agent_farm
Expand Down Expand Up @@ -121,6 +122,12 @@ def test_from_agent_farm_successful_import(
[mock_memory2], # One memory for agent 2
]

# Create a mock for SentenceTransformer
mock_sentence_transformer = MagicMock()
mock_sentence_transformer.__version__ = "2.2.2" # Add version attribute
mock_sentence_transformer.get_sentence_embedding_dimension.return_value = 384
mock_sentence_transformer.encode.return_value = np.array([0.1] * 384)

with patch(
"converter.converter.DatabaseManager", return_value=mock_db_manager
), patch(
Expand All @@ -129,7 +136,11 @@ def test_from_agent_farm_successful_import(
"converter.converter.MemoryImporter", return_value=mock_memory_importer
), patch(
"memory.core.AgentMemorySystem"
) as mock_memory_system:
) as mock_memory_system, patch(
"sentence_transformers.SentenceTransformer", return_value=mock_sentence_transformer
), patch(
"memory.embeddings.text_embeddings.SentenceTransformer", return_value=mock_sentence_transformer
):

# Configure mock memory system with two distinct agents
mock_memory_system.return_value.agents = {1: mock_agent1, 2: mock_agent2}
Expand Down Expand Up @@ -231,6 +242,12 @@ def test_from_agent_farm_import_verification(
[], # No memories for agent 1
]

# Create a mock for SentenceTransformer
mock_sentence_transformer = MagicMock()
mock_sentence_transformer.__version__ = "2.2.2"
mock_sentence_transformer.get_sentence_embedding_dimension.return_value = 384
mock_sentence_transformer.encode.return_value = np.array([0.1] * 384)

with patch(
"converter.converter.DatabaseManager", return_value=mock_db_manager
), patch(
Expand All @@ -239,16 +256,16 @@ def test_from_agent_farm_import_verification(
"converter.converter.MemoryImporter", return_value=mock_memory_importer
), patch(
"memory.core.AgentMemorySystem"
) as mock_memory_system:
) as mock_memory_system, patch(
"sentence_transformers.SentenceTransformer", return_value=mock_sentence_transformer
), patch(
"memory.embeddings.text_embeddings.SentenceTransformer", return_value=mock_sentence_transformer
):

# Mock memory system to simulate verification failure
# No agents in the system when we expect one
mock_memory_system.return_value.agents = {}

# Test should fail with agent count mismatch
with pytest.raises(ValueError) as exc_info:
with pytest.raises(ValueError, match="Import verification failed: agent count mismatch"):
from_agent_farm(str(db_path), config)

# Verify the error message contains agent count mismatch
error_msg = str(exc_info.value)
assert "Import verification failed: agent count mismatch" in error_msg
12 changes: 6 additions & 6 deletions tests/embeddings/test_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ def test_redis_vector_index_add(mock_redis_client):
result = index.add("test1", [0.1, 0.2, 0.3], {"name": "test"})

assert result is True
# Check that hset was called with the correct arguments
mock_redis_client.hset.assert_called_once()
args = mock_redis_client.hset.call_args[0]
# Check that hset_dict was called with the correct arguments
mock_redis_client.hset_dict.assert_called_once()
args = mock_redis_client.hset_dict.call_args[0]
assert args[0] == "test_index:test1"

# Check that mapping includes the vector field and metadata
mapping = mock_redis_client.hset.call_args[1]["mapping"]
mapping = mock_redis_client.hset_dict.call_args[0][1] # Get second positional argument
assert "embedding" in mapping
assert "metadata" in mapping
assert "timestamp" in mapping
Expand Down Expand Up @@ -370,8 +370,8 @@ def test_redis_vector_index_add_error(mock_redis_client):
"""Test error handling when adding vectors fails."""
index = RedisVectorIndex(mock_redis_client, "test_index")

# Make hset raise an exception
mock_redis_client.hset.side_effect = Exception("Test error")
# Make hset_dict raise an exception
mock_redis_client.hset_dict.side_effect = Exception("Test error")

result = index.add("test1", [0.1, 0.2, 0.3])
assert result is False
Expand Down
Loading
Loading