diff --git a/agents.py b/agents.py index 253f48a..15ed7ff 100644 --- a/agents.py +++ b/agents.py @@ -12,13 +12,9 @@ import numpy as np -from memory import ( - MemoryConfig, - MemorySpace, - RedisIMConfig, - RedisSTMConfig, - SQLiteLTMConfig, -) +from memory import (MemoryConfig, MemorySpace, RedisIMConfig, RedisSTMConfig, + SQLiteLTMConfig) +from memory.api.models import MazeActionSpace, MazeObservation from memory.utils.util import convert_numpy_to_python @@ -27,7 +23,7 @@ class SimpleAgent: def __init__( self, agent_id: str, - action_space: int = 4, + action_space: int | MazeActionSpace = 4, learning_rate: float = 0.1, discount_factor: float = 0.9, **kwargs, @@ -37,41 +33,48 @@ def __init__( Args: agent_id (str): Unique identifier for the agent. - action_space (int): Number of possible actions. + action_space (int or MazeActionSpace): Number of possible actions or MazeActionSpace object. learning_rate (float): Q-learning learning rate. discount_factor (float): Q-learning discount factor. - **kwargs: Additional arguments (unused). + **kwargs: Additional arguments. """ self.agent_id = agent_id - self.action_space = action_space + if isinstance(action_space, MazeActionSpace): + self.action_space = action_space.n + self.action_space_model = action_space + else: + self.action_space = action_space + self.action_space_model = MazeActionSpace(n=action_space) + self.learning_rate = learning_rate self.discount_factor = discount_factor - self.q_table = {} # State-action values + self.q_table = {} # State-action values #! fuzzier search than exact match self.current_observation = None self.demo_path = None # For scripted demo actions self.demo_step = 0 self.step_number = 0 - def _get_state_key(self, observation: dict) -> str: + for key, value in kwargs.items(): + setattr(self, key, value) + + def _get_state_key(self, observation: MazeObservation) -> str: """ Generate a unique key for a given observation/state. Args: - observation (dict): The environment observation. + observation (MazeObservation): The environment observation. Returns: str: A string key representing the state. """ - return ( - f"{observation['position']}|{observation['target']}|{observation['steps']}" - ) + return f"{observation.position}|{observation.target}|{observation.steps}" - def select_action(self, observation: dict, epsilon: float = 0.1) -> int: + def select_action(self, observation: MazeObservation, epsilon: float = 0.1) -> int: """ Select an action using an epsilon-greedy policy or a demonstration path. Args: - observation (dict): The current environment observation. + observation (MazeObservation): The current environment observation. epsilon (float): Probability of choosing a random action (exploration). Returns: @@ -98,20 +101,20 @@ def select_action(self, observation: dict, epsilon: float = 0.1) -> int: def update_q_value( self, - observation: dict, + observation: MazeObservation, action: int, reward: float, - next_observation: dict, + next_observation: MazeObservation, done: bool, ) -> None: """ Update the Q-value for a state-action pair using the Q-learning rule. Args: - observation (dict): The current state observation. + observation (MazeObservation): The current state observation. action (int): The action taken. reward (float): The reward received. - next_observation (dict): The next state observation. + next_observation (MazeObservation): The next state observation. done (bool): Whether the episode has ended. 
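# Illustrative sketch of the new typed inputs used above: an agent built from a
# MazeActionSpace and a state key built from a MazeObservation. Assumes
# SimpleAgent (defined above in agents.py) is in scope; values are illustrative.
from memory.api.models import MazeActionSpace, MazeObservation

agent = SimpleAgent("demo_agent", action_space=MazeActionSpace(n=4))
assert agent.action_space == 4                       # normalized to a plain int
assert agent.action_space_model.actions == ["up", "right", "down", "left"]

obs = MazeObservation(position=(0, 0), target=(3, 3), nearby_obstacles=[], steps=0)
assert agent._get_state_key(obs) == "(0, 0)|(3, 3)|0"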
""" state_key = self._get_state_key(observation) @@ -134,12 +137,12 @@ def update_q_value( ) self.q_table[state_key][action] = new_q - def act(self, observation: dict, epsilon: float = 0.1) -> int: + def act(self, observation: MazeObservation, epsilon: float = 0.1) -> int: """ Choose and return an action for the given observation. Args: - observation (dict): The current environment observation. + observation (MazeObservation): The current environment observation. epsilon (float): Probability of choosing a random action (exploration). Returns: @@ -147,9 +150,12 @@ def act(self, observation: dict, epsilon: float = 0.1) -> int: """ self.step_number += 1 # Convert NumPy types to Python types - self.current_observation = convert_numpy_to_python(observation) + obs = convert_numpy_to_python(observation) + if not isinstance(obs, MazeObservation): + obs = MazeObservation(**obs) + self.current_observation = obs action = self.select_action(self.current_observation, epsilon) - return int(action) # Return as integer instead of ActionResult + return int(action) def set_demo_path(self, path: list[int]) -> None: """ @@ -167,7 +173,7 @@ class MemoryAgent(SimpleAgent): def __init__( self, agent_id: str, - action_space: int = 4, + action_space: int | MazeActionSpace = 4, learning_rate: float = 0.1, discount_factor: float = 0.9, **kwargs, @@ -177,7 +183,7 @@ def __init__( Args: agent_id (str): Unique identifier for the agent. - action_space (int): Number of possible actions. + action_space (int or MazeActionSpace): Number of possible actions or MazeActionSpace object. learning_rate (float): Q-learning learning rate. discount_factor (float): Q-learning discount factor. **kwargs: Additional arguments (unused). @@ -213,19 +219,19 @@ def __init__( text_model_name="all-MiniLM-L6-v2", # Use a default text embedding model ) # Store the memory system and get the memory space for this agent - self.memory_space = MemorySpace(agent_id, memory_config) + self.memory = MemorySpace(agent_id, memory_config) # Keep track of visited states to avoid redundant storage self.visited_states = set() # Add memory cache for direct position lookups self.position_memory_cache = {} # Mapping from positions to memories - def select_action(self, observation: dict, epsilon: float = 0.1) -> int: + def select_action(self, observation: MazeObservation, epsilon: float = 0.1) -> int: """ Select an action using memory-augmented Q-learning and experience recall. Args: - observation (dict): The current environment observation. + observation (MazeObservation): The current environment observation. epsilon (float): Probability of choosing a random action (exploration). 
Returns: @@ -233,7 +239,7 @@ def select_action(self, observation: dict, epsilon: float = 0.1) -> int: """ self.current_observation = observation state_key = self._get_state_key(observation) - position_key = str(observation["position"]) # Use position as direct lookup key + position_key = str(observation.position) # Use position as direct lookup key # Initialize state if not seen before if state_key not in self.q_table: @@ -251,18 +257,18 @@ def select_action(self, observation: dict, epsilon: float = 0.1) -> int: if state_key not in self.visited_states: # Enhanced state representation enhanced_state = { - "position": observation["position"], - "target": observation["target"], - "steps": observation["steps"], - "nearby_obstacles": observation["nearby_obstacles"], + "position": observation.position, + "target": observation.target, + "steps": observation.steps, + "nearby_obstacles": observation.nearby_obstacles, "manhattan_distance": abs( - observation["position"][0] - observation["target"][0] + observation.position[0] - observation.target[0] ) - + abs(observation["position"][1] - observation["target"][1]), + + abs(observation.position[1] - observation.target[1]), "state_key": state_key, "position_key": position_key, # Add position key for direct lookup } - self.memory_space.store_state( + self.memory.store_state( state_data=convert_numpy_to_python(enhanced_state), step_number=self.step_number, priority=0.7, # Medium priority for state @@ -271,17 +277,18 @@ def select_action(self, observation: dict, epsilon: float = 0.1) -> int: # Create a query with the enhanced state features query_state = { - "position": observation["position"], - "target": observation["target"], - "steps": observation["steps"], + "position": observation.position, + "target": observation.target, + "steps": observation.steps, "manhattan_distance": abs( - observation["position"][0] - observation["target"][0] + observation.position[0] - observation.target[0] ) - + abs(observation["position"][1] - observation["target"][1]), + + abs(observation.position[1] - observation.target[1]), } # Use search strategy directly - similar_states = self.memory_space.retrieve_similar_states( + #! 
Needs to be updated to use the SimilaritySearch strategy + similar_states = self.memory.retrieve_similar_states( query_state=query_state, k=10, # Increase from 5 to 10 to find more candidates memory_type="state", @@ -360,21 +367,24 @@ def act(self, observation: dict, epsilon: float = 0.1) -> int: """ self.step_number += 1 # Convert NumPy types to Python types - self.current_observation = convert_numpy_to_python(observation) + obs = convert_numpy_to_python(observation) + if not isinstance(obs, MazeObservation): + obs = MazeObservation(**obs) + self.current_observation = obs action = self.select_action(self.current_observation, epsilon) # Store the action using memory space try: # Include more context in the action data - position_key = str(observation["position"]) + position_key = str(observation.position) action_data = { "action": int(action), - "position": self.current_observation["position"], + "position": self.current_observation.position, "state_key": self._get_state_key(self.current_observation), - "steps": self.current_observation["steps"], + "steps": self.current_observation.steps, "position_key": position_key, } - self.memory_space.store_action( + self.memory.store_action( action_data=action_data, step_number=self.step_number, priority=0.6, # Medium priority @@ -397,20 +407,20 @@ def act(self, observation: dict, epsilon: float = 0.1) -> int: def update_q_value( self, - observation: dict, + observation: MazeObservation, action: int, reward: float, - next_observation: dict, + next_observation: MazeObservation, done: bool, ) -> None: """ Update the Q-value and store the reward and outcome in memory. Args: - observation (dict): The current state observation. + observation (MazeObservation): The current state observation. action (int): The action taken. reward (float): The reward received. - next_observation (dict): The next state observation. + next_observation (MazeObservation): The next state observation. done (bool): Whether the episode has ended. 
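# Sketch of the memory query assembled above: raw observation fields plus the
# Manhattan distance to the target as an extra similarity feature (values are illustrative).
position, target, steps = (1, 2), (6, 6), 7
query_state = {
    "position": position,
    "target": target,
    "steps": steps,
    "manhattan_distance": abs(position[0] - target[0]) + abs(position[1] - target[1]),
}
print(query_state["manhattan_distance"])  # 5 + 4 = 9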
""" # First, call the parent method to update Q-values @@ -419,21 +429,21 @@ def update_q_value( # Then store the reward and outcome using memory space try: # Enhance interaction data with more context - position_key = str(observation["position"]) - next_position_key = str(next_observation["position"]) + position_key = str(observation.position) + next_position_key = str(next_observation.position) interaction_data = { "action": int(action), "reward": float(reward), - "next_state": convert_numpy_to_python(next_observation["position"]), + "next_state": convert_numpy_to_python(next_observation.position), "done": done, "state_key": self._get_state_key(observation), "next_state_key": self._get_state_key(next_observation), - "steps": observation["steps"], + "steps": observation.steps, "manhattan_distance": abs( - observation["position"][0] - observation["target"][0] + observation.position[0] - observation.target[0] ) - + abs(observation["position"][1] - observation["target"][1]), + + abs(observation.position[1] - observation.target[1]), "position_key": position_key, "next_position_key": next_position_key, } @@ -443,7 +453,7 @@ def update_q_value( if done and reward > 0: # Successful completion priority = 1.0 # Maximum priority - self.memory_space.store_interaction( + self.memory.store_interaction( interaction_data=interaction_data, step_number=self.step_number, priority=priority, @@ -475,6 +485,6 @@ def update_q_value( # Random agent that chooses actions randomly class RandomAgent(SimpleAgent): - def select_action(self, observation, epsilon=0.1): + def select_action(self, observation: MazeObservation, epsilon: float = 0.1) -> int: self.current_observation = observation return np.random.randint(self.action_space) diff --git a/agents/__init__.py b/agents/__init__.py new file mode 100644 index 0000000..8e83aa5 --- /dev/null +++ b/agents/__init__.py @@ -0,0 +1 @@ +from .base import Agent diff --git a/agents/algo_agent.py b/agents/algo_agent.py new file mode 100644 index 0000000..286952a --- /dev/null +++ b/agents/algo_agent.py @@ -0,0 +1,288 @@ +from collections import deque + +import numpy as np + +from agents import Agent +from memory.api.models import MazeActionSpace, MazeObservation +from memory.config.memory_config import ( + MemoryConfig, + RedisIMConfig, + RedisSTMConfig, + SQLiteLTMConfig, +) +from memory.space import MemorySpace +from memory.utils.util import convert_numpy_to_python + + +class AlgoAgent(Agent): + def __init__( + self, + agent_id: str, + action_space: int | MazeActionSpace = 4, + search_algo: str = "bfs", + **kwargs, + ): + if isinstance(action_space, MazeActionSpace): + self.action_space = action_space.n + self.action_space_model = action_space + else: + self.action_space = action_space + self.action_space_model = MazeActionSpace(n=action_space) + self.agent_id = agent_id + self.search_algo = search_algo + self.demo_path = None + self.demo_step = 0 + self.last_plan = [] + for key, value in kwargs.items(): + setattr(self, key, value) + + def act(self, observation: MazeObservation, epsilon: float = 0.1) -> int: + if self.demo_path is not None and self.demo_step < len(self.demo_path): + action = self.demo_path[self.demo_step] + self.demo_step += 1 + return int(action) + obs = convert_numpy_to_python(observation) + if not isinstance(obs, MazeObservation): + obs = MazeObservation(**obs) + # Plan path if needed + if not self.last_plan or obs.position != self.last_plan[0]: + self.last_plan = self._plan_path(obs) + if len(self.last_plan) < 2: + # No path or already at target + 
return np.random.randint(self.action_space) + # Determine action to move from current to next position + current = self.last_plan[0] + next_pos = self.last_plan[1] + action = self._get_action_from_positions(current, next_pos) + # Advance plan + self.last_plan = self.last_plan[1:] + return int(action) + + def set_demo_path(self, path: list[int]) -> None: + self.demo_path = path + self.demo_step = 0 + + def _plan_path(self, obs: MazeObservation): + if callable(self.search_algo): + return self.search_algo(obs) + if self.search_algo == "bfs": + return self._bfs(obs) + elif self.search_algo == "dfs": + return self._dfs(obs) + else: + # Default to BFS + return self._bfs(obs) + + def _bfs(self, obs: MazeObservation): + start = obs.position + target = obs.target + size = max(max(start), max(target)) + 2 # crude estimate + obstacles = set(obs.nearby_obstacles) + queue = deque() + queue.append((start, [start])) + visited = set() + while queue: + pos, path = queue.popleft() + if pos == target: + return path + if pos in visited: + continue + visited.add(pos) + for action, (dr, dc) in enumerate([(-1, 0), (0, 1), (1, 0), (0, -1)]): + new_pos = (pos[0] + dr, pos[1] + dc) + if new_pos in visited or new_pos in obstacles: + continue + if ( + new_pos[0] < 0 + or new_pos[1] < 0 + or new_pos[0] >= size + or new_pos[1] >= size + ): + continue + queue.append((new_pos, path + [new_pos])) + return [start] # No path found + + def _dfs(self, obs: MazeObservation): + start = obs.position + target = obs.target + size = max(max(start), max(target)) + 2 + obstacles = set(obs.nearby_obstacles) + stack = [(start, [start])] + visited = set() + while stack: + pos, path = stack.pop() + if pos == target: + return path + if pos in visited: + continue + visited.add(pos) + for action, (dr, dc) in enumerate([(-1, 0), (0, 1), (1, 0), (0, -1)]): + new_pos = (pos[0] + dr, pos[1] + dc) + if new_pos in visited or new_pos in obstacles: + continue + if ( + new_pos[0] < 0 + or new_pos[1] < 0 + or new_pos[0] >= size + or new_pos[1] >= size + ): + continue + stack.append((new_pos, path + [new_pos])) + return [start] + + def _get_action_from_positions(self, current, next_pos): + # Returns action index to move from current to next_pos + dr = next_pos[0] - current[0] + dc = next_pos[1] - current[1] + if dr == -1 and dc == 0: + return 0 # up + elif dr == 0 and dc == 1: + return 1 # right + elif dr == 1 and dc == 0: + return 2 # down + elif dr == 0 and dc == -1: + return 3 # left + else: + return np.random.randint(self.action_space) + + +class MemoryAlgoAgent(AlgoAgent): + def __init__( + self, + agent_id: str, + action_space: int | MazeActionSpace = 4, + search_algo: str = "bfs", + **kwargs, + ): + super().__init__( + agent_id=agent_id, + action_space=action_space, + search_algo=search_algo, + **kwargs, + ) + memory_config = MemoryConfig( + stm_config=RedisSTMConfig( + ttl=120, + memory_limit=500, + use_mock=True, + ), + im_config=RedisIMConfig( + ttl=240, + memory_limit=1000, + compression_level=0, + use_mock=True, + ), + ltm_config=SQLiteLTMConfig( + compression_level=0, + batch_size=20, + db_path="memory_demo.db", + ), + cleanup_interval=1000, + enable_memory_hooks=False, + use_embedding_engine=True, + text_model_name="all-MiniLM-L6-v2", + ) + self.memory = MemorySpace(agent_id, memory_config) + self.position_memory_cache = {} + self.step_number = 0 + + def act(self, observation: MazeObservation, epsilon: float = 0.1) -> int: + self.step_number += 1 + if self.demo_path is not None and self.demo_step < len(self.demo_path): + action = 
self.demo_path[self.demo_step] + self.demo_step += 1 + return int(action) + obs = convert_numpy_to_python(observation) + if not isinstance(obs, MazeObservation): + obs = MazeObservation(**obs) + position_key = str(obs.position) + state_key = f"{obs.position}|{obs.target}|{obs.steps}" + # Try to retrieve similar states from memory + try: + query_state = { + "position": obs.position, + "target": obs.target, + "steps": obs.steps, + "manhattan_distance": abs(obs.position[0] - obs.target[0]) + + abs(obs.position[1] - obs.target[1]), + } + similar_states = self.memory.retrieve_similar_states( + query_state=query_state, + k=10, + memory_type="state", + ) + if len(similar_states) == 0 and position_key in self.position_memory_cache: + similar_states = self.position_memory_cache[position_key] + for s in similar_states: + mem_position = None + if "position" in s.get("content", {}): + mem_position = str(s["content"]["position"]) + elif "next_state" in s.get("content", {}): + mem_position = str(s["content"]["next_state"]) + if mem_position: + if mem_position not in self.position_memory_cache: + self.position_memory_cache[mem_position] = [] + if s not in self.position_memory_cache[mem_position]: + self.position_memory_cache[mem_position].append(s) + if similar_states and np.random.random() > 0.2: + actions_from_memory = [] + for s in similar_states: + if "action" in s.get("content", {}): + reward = s["content"].get("reward", -1) + weight = 1 + if reward > -2: + weight = 3 + if reward > 0: + weight = 5 + for _ in range(weight): + actions_from_memory.append(s["content"]["action"]) + if actions_from_memory: + chosen_action = max( + set(actions_from_memory), key=actions_from_memory.count + ) + # Store action in memory + self._store_action_and_interaction( + obs, chosen_action, None, None, None + ) + return int(chosen_action) + except Exception: + pass + # Otherwise, use the search algorithm + if not self.last_plan or obs.position != self.last_plan[0]: + self.last_plan = self._plan_path(obs) + if len(self.last_plan) < 2: + action = np.random.randint(self.action_space) + else: + current = self.last_plan[0] + next_pos = self.last_plan[1] + action = self._get_action_from_positions(current, next_pos) + self.last_plan = self.last_plan[1:] + # Store action in memory + self._store_action_and_interaction(obs, action, None, None, None) + return int(action) + + def _store_action_and_interaction(self, obs, action, reward, next_obs, done): + try: + position_key = str(obs.position) + action_data = { + "action": int(action), + "position": obs.position, + "state_key": f"{obs.position}|{obs.target}|{obs.steps}", + "steps": obs.steps, + "position_key": position_key, + } + self.memory.store_action( + action_data=action_data, + step_number=self.step_number, + priority=0.6, + ) + if position_key not in self.position_memory_cache: + self.position_memory_cache[position_key] = [] + memory_entry = {"content": action_data, "step_number": self.step_number} + self.position_memory_cache[position_key].append(memory_entry) + except Exception: + pass + + def set_demo_path(self, path: list[int]) -> None: + self.demo_path = path + self.demo_step = 0 diff --git a/agents/base.py b/agents/base.py new file mode 100644 index 0000000..2221ad7 --- /dev/null +++ b/agents/base.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod + +from memory.api.models import MazeObservation + + +class Agent(ABC): + """ + Abstract base class for all agents. 
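# Standalone sketch of the reward-weighted vote MemoryAlgoAgent.act uses above:
# recalled actions get 1/3/5 votes depending on their stored reward, then the
# most frequent action wins. The sample memories below are illustrative.
recalled = [
    {"content": {"action": 1, "reward": 10.0}},   # rewarding outcome -> weight 5
    {"content": {"action": 2, "reward": -1.0}},   # ordinary step     -> weight 3
    {"content": {"action": 2, "reward": -5.0}},   # poor outcome      -> weight 1
]
votes = []
for s in recalled:
    reward = s["content"].get("reward", -1)
    weight = 1
    if reward > -2:
        weight = 3
    if reward > 0:
        weight = 5
    votes.extend([s["content"]["action"]] * weight)
chosen_action = max(set(votes), key=votes.count)
print(chosen_action)  # 1 (five weighted votes vs. four for action 2)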
+ """ + + def __init__(self, agent_id: str, action_space, **kwargs): + self.agent_id = agent_id + self.action_space = action_space + for key, value in kwargs.items(): + setattr(self, key, value) + + @abstractmethod + def act(self, observation: MazeObservation, epsilon: float = 0.1) -> int: + """ + Choose and return an action for the given observation. + """ + pass + + @abstractmethod + def set_demo_path(self, path: list[int]) -> None: + """ + Set a predetermined path of actions for demonstration or scripted exploration. + """ + pass diff --git a/agents/deep_q_agent.py b/agents/deep_q_agent.py new file mode 100644 index 0000000..c2a5bc0 --- /dev/null +++ b/agents/deep_q_agent.py @@ -0,0 +1,344 @@ +import random +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from collections import deque + +from agents import Agent +from memory.api.models import MazeActionSpace, MazeObservation +from memory.utils.util import convert_numpy_to_python +from memory.config.memory_config import MemoryConfig, RedisIMConfig, RedisSTMConfig, SQLiteLTMConfig +from memory.space import MemorySpace + + +def observation_to_tensor(obs: MazeObservation) -> torch.Tensor: + # Flatten observation: position (2), target (2), steps (1), obstacles (up to 8 nearby obstacles, each 2) + pos = list(obs.position) + tgt = list(obs.target) + steps = [obs.steps] + # Pad or truncate nearby_obstacles to 8 + obstacles = list(obs.nearby_obstacles)[:8] + flat_obs = pos + tgt + steps + for ob in obstacles: + flat_obs.extend(list(ob)) + # Pad if fewer than 8 obstacles + while len(flat_obs) < 2 + 2 + 1 + 8 * 2: + flat_obs.append(0) + return torch.tensor(flat_obs, dtype=torch.float32) + + +class DQN(nn.Module): + def __init__(self, input_dim, output_dim): + super().__init__() + self.net = nn.Sequential( + nn.Linear(input_dim, 64), + nn.ReLU(), + nn.Linear(64, 64), + nn.ReLU(), + nn.Linear(64, output_dim), + ) + + def forward(self, x): + return self.net(x) + + +class DeepQAgent(Agent): + def __init__( + self, + agent_id: str, + action_space: int | MazeActionSpace = 4, + learning_rate: float = 1e-3, + discount_factor: float = 0.99, + batch_size: int = 32, + memory_size: int = 10000, + target_update: int = 100, + device: str = None, + **kwargs, + ): + if isinstance(action_space, MazeActionSpace): + self.action_space = action_space.n + self.action_space_model = action_space + else: + self.action_space = action_space + self.action_space_model = MazeActionSpace(n=action_space) + self.agent_id = agent_id + self.learning_rate = learning_rate + self.discount_factor = discount_factor + self.batch_size = batch_size + self.memory = deque(maxlen=memory_size) + self.target_update = target_update + self.step_number = 0 + self.demo_path = None + self.demo_step = 0 + self.current_observation = None + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self.input_dim = 2 + 2 + 1 + 8 * 2 # pos(2), tgt(2), steps(1), obstacles(8x2) + self.policy_net = DQN(self.input_dim, self.action_space).to(self.device) + self.target_net = DQN(self.input_dim, self.action_space).to(self.device) + self.target_net.load_state_dict(self.policy_net.state_dict()) + self.target_net.eval() + self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate) + for key, value in kwargs.items(): + setattr(self, key, value) + + def act(self, observation: MazeObservation, epsilon: float = 0.1) -> int: + self.step_number += 1 + obs = convert_numpy_to_python(observation) + if not isinstance(obs, MazeObservation): + obs = 
MazeObservation(**obs) + self.current_observation = obs + if self.demo_path is not None and self.demo_step < len(self.demo_path): + action = self.demo_path[self.demo_step] + self.demo_step += 1 + return int(action) + if np.random.random() < epsilon: + return np.random.randint(self.action_space) + obs_tensor = observation_to_tensor(obs).unsqueeze(0).to(self.device) + with torch.no_grad(): + q_values = self.policy_net(obs_tensor) + return int(torch.argmax(q_values).item()) + + def set_demo_path(self, path: list[int]) -> None: + self.demo_path = path + self.demo_step = 0 + + def remember(self, obs, action, reward, next_obs, done): + self.memory.append((obs, action, reward, next_obs, done)) + + def update(self): + if len(self.memory) < self.batch_size: + return + batch = random.sample(self.memory, self.batch_size) + obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = zip(*batch) + obs_tensor = torch.stack([observation_to_tensor(o) for o in obs_batch]).to(self.device) + action_tensor = torch.tensor(action_batch, dtype=torch.long, device=self.device).unsqueeze(1) + reward_tensor = torch.tensor(reward_batch, dtype=torch.float32, device=self.device).unsqueeze(1) + next_obs_tensor = torch.stack([observation_to_tensor(o) for o in next_obs_batch]).to(self.device) + done_tensor = torch.tensor(done_batch, dtype=torch.float32, device=self.device).unsqueeze(1) + # Q(s,a) + q_values = self.policy_net(obs_tensor).gather(1, action_tensor) + # max_a' Q_target(s',a') + with torch.no_grad(): + next_q_values = self.target_net(next_obs_tensor).max(1, keepdim=True)[0] + target = reward_tensor + self.discount_factor * next_q_values * (1 - done_tensor) + loss = nn.functional.mse_loss(q_values, target) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + # Periodically update target network + if self.step_number % self.target_update == 0: + self.target_net.load_state_dict(self.policy_net.state_dict()) + + def observe(self, observation, action, reward, next_observation, done): + # Store experience and train + obs = convert_numpy_to_python(observation) + if not isinstance(obs, MazeObservation): + obs = MazeObservation(**obs) + next_obs = convert_numpy_to_python(next_observation) + if not isinstance(next_obs, MazeObservation): + next_obs = MazeObservation(**next_obs) + self.remember(obs, action, reward, next_obs, done) + self.update() + + +class MemoryDeepQAgent(DeepQAgent): + def __init__( + self, + agent_id: str, + action_space: int | MazeActionSpace = 4, + learning_rate: float = 1e-3, + discount_factor: float = 0.99, + batch_size: int = 32, + memory_size: int = 10000, + target_update: int = 100, + device: str = None, + **kwargs, + ): + super().__init__( + agent_id=agent_id, + action_space=action_space, + learning_rate=learning_rate, + discount_factor=discount_factor, + batch_size=batch_size, + memory_size=memory_size, + target_update=target_update, + device=device, + **kwargs, + ) + memory_config = MemoryConfig( + stm_config=RedisSTMConfig( + ttl=120, + memory_limit=500, + use_mock=True, + ), + im_config=RedisIMConfig( + ttl=240, + memory_limit=1000, + compression_level=0, + use_mock=True, + ), + ltm_config=SQLiteLTMConfig( + compression_level=0, + batch_size=20, + db_path="memory_demo.db", + ), + cleanup_interval=1000, + enable_memory_hooks=False, + use_embedding_engine=True, + text_model_name="all-MiniLM-L6-v2", + ) + self.memory = MemorySpace(agent_id, memory_config) + self.visited_states = set() + self.position_memory_cache = {} + + def select_action(self, observation: 
MazeObservation, epsilon: float = 0.1) -> int: + self.current_observation = observation + state_key = f"{observation.position}|{observation.target}|{observation.steps}" + position_key = str(observation.position) + if self.demo_path is not None and self.demo_step < len(self.demo_path): + action = self.demo_path[self.demo_step] + self.demo_step += 1 + return int(action) + try: + if state_key not in self.visited_states: + enhanced_state = { + "position": observation.position, + "target": observation.target, + "steps": observation.steps, + "nearby_obstacles": getattr(observation, "nearby_obstacles", None), + "manhattan_distance": abs(observation.position[0] - observation.target[0]) + abs(observation.position[1] - observation.target[1]), + "state_key": state_key, + "position_key": position_key, + } + self.memory.store_state( + state_data=convert_numpy_to_python(enhanced_state), + step_number=self.step_number, + priority=0.7, + ) + self.visited_states.add(state_key) + query_state = { + "position": observation.position, + "target": observation.target, + "steps": observation.steps, + "manhattan_distance": abs(observation.position[0] - observation.target[0]) + abs(observation.position[1] - observation.target[1]), + } + similar_states = self.memory.retrieve_similar_states( + query_state=query_state, + k=10, + memory_type="state", + ) + if len(similar_states) == 0: + if position_key in self.position_memory_cache: + similar_states = self.position_memory_cache[position_key] + for s in similar_states: + mem_position = None + if "position" in s.get("content", {}): + mem_position = str(s["content"]["position"]) + elif "next_state" in s.get("content", {}): + mem_position = str(s["content"]["next_state"]) + if mem_position: + if mem_position not in self.position_memory_cache: + self.position_memory_cache[mem_position] = [] + if s not in self.position_memory_cache[mem_position]: + self.position_memory_cache[mem_position].append(s) + if similar_states and np.random.random() > 0.2: + actions_from_memory = [] + for s in similar_states: + if "action" in s.get("content", {}): + reward = s["content"].get("reward", -1) + weight = 1 + if reward > -2: + weight = 3 + if reward > 0: + weight = 5 + for _ in range(weight): + actions_from_memory.append(s["content"]["action"]) + if actions_from_memory: + chosen_action = max(set(actions_from_memory), key=actions_from_memory.count) + return int(chosen_action) + except Exception: + pass + # Fallback to DQN policy or random + if np.random.random() < epsilon: + return np.random.randint(self.action_space) + obs_tensor = observation_to_tensor(observation).unsqueeze(0).to(self.device) + with torch.no_grad(): + q_values = self.policy_net(obs_tensor) + return int(torch.argmax(q_values).item()) + + def act(self, observation: MazeObservation, epsilon: float = 0.1) -> int: + self.step_number += 1 + obs = convert_numpy_to_python(observation) + if not isinstance(obs, MazeObservation): + obs = MazeObservation(**obs) + self.current_observation = obs + action = self.select_action(self.current_observation, epsilon) + try: + position_key = str(observation.position) + action_data = { + "action": int(action), + "position": self.current_observation.position, + "state_key": f"{self.current_observation.position}|{self.current_observation.target}|{self.current_observation.steps}", + "steps": self.current_observation.steps, + "position_key": position_key, + } + self.memory.store_action( + action_data=action_data, + step_number=self.step_number, + priority=0.6, + ) + if position_key not in 
self.position_memory_cache: + self.position_memory_cache[position_key] = [] + memory_entry = {"content": action_data, "step_number": self.step_number} + self.position_memory_cache[position_key].append(memory_entry) + except Exception: + pass + return int(action) + + def observe(self, observation, action, reward, next_observation, done): + # Store experience and train + obs = convert_numpy_to_python(observation) + if not isinstance(obs, MazeObservation): + obs = MazeObservation(**obs) + next_obs = convert_numpy_to_python(next_observation) + if not isinstance(next_obs, MazeObservation): + next_obs = MazeObservation(**next_obs) + self.remember(obs, action, reward, next_obs, done) + try: + position_key = str(observation.position) + next_position_key = str(next_observation.position) + interaction_data = { + "action": int(action), + "reward": float(reward), + "next_state": convert_numpy_to_python(next_observation.position), + "done": done, + "state_key": f"{observation.position}|{observation.target}|{observation.steps}", + "next_state_key": f"{next_observation.position}|{next_observation.target}|{next_observation.steps}", + "steps": observation.steps, + "manhattan_distance": abs(observation.position[0] - observation.target[0]) + abs(observation.position[1] - observation.target[1]), + "position_key": position_key, + "next_position_key": next_position_key, + } + priority = abs(float(reward)) / 100 + if done and reward > 0: + priority = 1.0 + self.memory.store_interaction( + interaction_data=interaction_data, + step_number=self.step_number, + priority=priority, + ) + for pos_key in [position_key, next_position_key]: + if pos_key not in self.position_memory_cache: + self.position_memory_cache[pos_key] = [] + memory_entry = { + "content": interaction_data, + "step_number": self.step_number, + } + self.position_memory_cache[pos_key].append(memory_entry) + if done and reward > 0: + for _ in range(10): + self.position_memory_cache[position_key].append(memory_entry) + except Exception: + pass + self.update() diff --git a/agents/q_agent.py b/agents/q_agent.py new file mode 100644 index 0000000..cffc8b5 --- /dev/null +++ b/agents/q_agent.py @@ -0,0 +1,364 @@ +import numpy as np + +from agents import Agent +from memory import ( + MemoryConfig, + MemorySpace, + RedisIMConfig, + RedisSTMConfig, + SQLiteLTMConfig, +) +from memory.api.models import MazeActionSpace, MazeObservation +from memory.utils.util import convert_numpy_to_python + + +class QAgent(Agent): + def __init__( + self, + agent_id: str, + action_space: int | MazeActionSpace = 4, + learning_rate: float = 0.1, + discount_factor: float = 0.9, + **kwargs, + ) -> None: + """ + Initialize a q-learning agent for reinforcement learning. + + Args: + agent_id (str): Unique identifier for the agent. + action_space (int or MazeActionSpace): Number of possible actions or MazeActionSpace object. + learning_rate (float): Q-learning learning rate. + discount_factor (float): Q-learning discount factor. + **kwargs: Additional arguments. + """ + self.agent_id = agent_id + if isinstance(action_space, MazeActionSpace): + self.action_space = action_space.n + self.action_space_model = action_space + else: + self.action_space = action_space + self.action_space_model = MazeActionSpace(n=action_space) + + self.learning_rate = learning_rate + self.discount_factor = discount_factor + self.q_table = {} # State-action values #! 
fuzzier search than exact match + self.current_observation = None + self.demo_path = None # For scripted demo actions + self.demo_step = 0 + self.step_number = 0 + + for key, value in kwargs.items(): + setattr(self, key, value) + + def _get_state_key(self, observation: MazeObservation) -> str: + """ + Generate a unique key for a given observation/state. + + Args: + observation (MazeObservation): The environment observation. + + Returns: + str: A string key representing the state. + """ + return f"{observation.position}|{observation.target}|{observation.steps}" + + def select_action(self, observation: MazeObservation, epsilon: float = 0.1) -> int: + """ + Select an action using an epsilon-greedy policy or a demonstration path. + + Args: + observation (MazeObservation): The current environment observation. + epsilon (float): Probability of choosing a random action (exploration). + + Returns: + int: The selected action index. + """ + self.current_observation = observation + state_key = self._get_state_key(observation) + + # Initialize state if not seen before + if state_key not in self.q_table: + self.q_table[state_key] = np.zeros(self.action_space) + + # If we have a demo path, follow it first to ensure we explore the correct path + if self.demo_path is not None and self.demo_step < len(self.demo_path): + action = self.demo_path[self.demo_step] + self.demo_step += 1 + return action + + # Epsilon-greedy policy + if np.random.random() < epsilon: + return np.random.randint(self.action_space) + else: + return np.argmax(self.q_table[state_key]) + + def update_q_value( + self, + observation: MazeObservation, + action: int, + reward: float, + next_observation: MazeObservation, + done: bool, + ) -> None: + """ + Update the Q-value for a state-action pair using the Q-learning rule. + + Args: + observation (MazeObservation): The current state observation. + action (int): The action taken. + reward (float): The reward received. + next_observation (MazeObservation): The next state observation. + done (bool): Whether the episode has ended. + """ + state_key = self._get_state_key(observation) + next_state_key = self._get_state_key(next_observation) + + # Initialize next state if not seen before + if next_state_key not in self.q_table: + self.q_table[next_state_key] = np.zeros(self.action_space) + + # Q-learning update + current_q = self.q_table[state_key][action] + + if done: + max_next_q = 0 + else: + max_next_q = np.max(self.q_table[next_state_key]) + + new_q = current_q + self.learning_rate * ( + reward + self.discount_factor * max_next_q - current_q + ) + self.q_table[state_key][action] = new_q + + def act(self, observation: MazeObservation, epsilon: float = 0.1) -> int: + """ + Choose and return an action for the given observation. + + Args: + observation (MazeObservation): The current environment observation. + epsilon (float): Probability of choosing a random action (exploration). + + Returns: + int: The selected action index. + """ + self.step_number += 1 + # Convert NumPy types to Python types + obs = convert_numpy_to_python(observation) + if not isinstance(obs, MazeObservation): + obs = MazeObservation(**obs) + self.current_observation = obs + action = self.select_action(self.current_observation, epsilon) + return int(action) + + def set_demo_path(self, path: list[int]) -> None: + """ + Set a predetermined path of actions for demonstration or scripted exploration. + + Args: + path (list[int]): List of action indices to follow. 
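# Sketch of the scripted-demo behaviour above: while a demo path is set,
# select_action returns its steps in order and ignores epsilon. Import paths
# follow the new agents/ package introduced in this diff.
from agents.q_agent import QAgent
from memory.api.models import MazeObservation

agent = QAgent("scripted_agent")
agent.set_demo_path([1, 1, 2, 2])   # right, right, down, down
obs = MazeObservation(position=(1, 1), target=(3, 3), nearby_obstacles=[], steps=0)
print([agent.select_action(obs, epsilon=1.0) for _ in range(4)])  # [1, 1, 2, 2]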
+ """ + self.demo_path = path + self.demo_step = 0 + + +class MemoryQAgent(QAgent): + def __init__( + self, + agent_id: str, + action_space: int | MazeActionSpace = 4, + learning_rate: float = 0.1, + discount_factor: float = 0.9, + **kwargs, + ) -> None: + super().__init__( + agent_id=agent_id, + action_space=action_space, + learning_rate=learning_rate, + discount_factor=discount_factor, + **kwargs, + ) + memory_config = MemoryConfig( + stm_config=RedisSTMConfig( + ttl=120, + memory_limit=500, + use_mock=True, + ), + im_config=RedisIMConfig( + ttl=240, + memory_limit=1000, + compression_level=0, + use_mock=True, + ), + ltm_config=SQLiteLTMConfig( + compression_level=0, + batch_size=20, + db_path="memory_demo.db", + ), + cleanup_interval=1000, + enable_memory_hooks=False, + use_embedding_engine=True, + text_model_name="all-MiniLM-L6-v2", + ) + self.memory = MemorySpace(agent_id, memory_config) + self.visited_states = set() + self.position_memory_cache = {} + + def select_action(self, observation: MazeObservation, epsilon: float = 0.1) -> int: + self.current_observation = observation + state_key = self._get_state_key(observation) + position_key = str(observation.position) + if state_key not in self.q_table: + self.q_table[state_key] = np.zeros(self.action_space) + if self.demo_path is not None and self.demo_step < len(self.demo_path): + action = self.demo_path[self.demo_step] + self.demo_step += 1 + return action + try: + if state_key not in self.visited_states: + enhanced_state = { + "position": observation.position, + "target": observation.target, + "steps": observation.steps, + "nearby_obstacles": getattr(observation, "nearby_obstacles", None), + "manhattan_distance": abs( + observation.position[0] - observation.target[0] + ) + + abs(observation.position[1] - observation.target[1]), + "state_key": state_key, + "position_key": position_key, + } + self.memory.store_state( + state_data=convert_numpy_to_python(enhanced_state), + step_number=self.step_number, + priority=0.7, + ) + self.visited_states.add(state_key) + query_state = { + "position": observation.position, + "target": observation.target, + "steps": observation.steps, + "manhattan_distance": abs( + observation.position[0] - observation.target[0] + ) + + abs(observation.position[1] - observation.target[1]), + } + similar_states = self.memory.retrieve_similar_states( + query_state=query_state, + k=10, + memory_type="state", + ) + if len(similar_states) == 0: + if position_key in self.position_memory_cache: + similar_states = self.position_memory_cache[position_key] + for s in similar_states: + mem_position = None + if "position" in s.get("content", {}): + mem_position = str(s["content"]["position"]) + elif "next_state" in s.get("content", {}): + mem_position = str(s["content"]["next_state"]) + if mem_position: + if mem_position not in self.position_memory_cache: + self.position_memory_cache[mem_position] = [] + if s not in self.position_memory_cache[mem_position]: + self.position_memory_cache[mem_position].append(s) + if similar_states and np.random.random() > 0.2: + actions_from_memory = [] + for s in similar_states: + if "action" in s.get("content", {}): + reward = s["content"].get("reward", -1) + weight = 1 + if reward > -2: + weight = 3 + if reward > 0: + weight = 5 + for _ in range(weight): + actions_from_memory.append(s["content"]["action"]) + if actions_from_memory: + chosen_action = max( + set(actions_from_memory), key=actions_from_memory.count + ) + return chosen_action + except Exception: + pass + if np.random.random() < 
epsilon: + return np.random.randint(self.action_space) + else: + return int(np.argmax(self.q_table[state_key])) + + def act(self, observation: MazeObservation, epsilon: float = 0.1) -> int: + self.step_number += 1 + obs = convert_numpy_to_python(observation) + if not isinstance(obs, MazeObservation): + obs = MazeObservation(**obs) + self.current_observation = obs + action = self.select_action(self.current_observation, epsilon) + try: + position_key = str(observation.position) + action_data = { + "action": int(action), + "position": self.current_observation.position, + "state_key": self._get_state_key(self.current_observation), + "steps": self.current_observation.steps, + "position_key": position_key, + } + self.memory.store_action( + action_data=action_data, + step_number=self.step_number, + priority=0.6, + ) + if position_key not in self.position_memory_cache: + self.position_memory_cache[position_key] = [] + memory_entry = {"content": action_data, "step_number": self.step_number} + self.position_memory_cache[position_key].append(memory_entry) + except Exception: + pass + return int(action) + + def update_q_value( + self, + observation: MazeObservation, + action: int, + reward: float, + next_observation: MazeObservation, + done: bool, + ) -> None: + super().update_q_value(observation, action, reward, next_observation, done) + try: + position_key = str(observation.position) + next_position_key = str(next_observation.position) + interaction_data = { + "action": int(action), + "reward": float(reward), + "next_state": convert_numpy_to_python(next_observation.position), + "done": done, + "state_key": self._get_state_key(observation), + "next_state_key": self._get_state_key(next_observation), + "steps": observation.steps, + "manhattan_distance": abs( + observation.position[0] - observation.target[0] + ) + + abs(observation.position[1] - observation.target[1]), + "position_key": position_key, + "next_position_key": next_position_key, + } + priority = abs(float(reward)) / 100 + if done and reward > 0: + priority = 1.0 + self.memory.store_interaction( + interaction_data=interaction_data, + step_number=self.step_number, + priority=priority, + ) + for pos_key in [position_key, next_position_key]: + if pos_key not in self.position_memory_cache: + self.position_memory_cache[pos_key] = [] + memory_entry = { + "content": interaction_data, + "step_number": self.step_number, + } + self.position_memory_cache[pos_key].append(memory_entry) + if done and reward > 0: + for _ in range(10): + self.position_memory_cache[position_key].append(memory_entry) + except Exception: + pass diff --git a/agents/random_agent.py b/agents/random_agent.py new file mode 100644 index 0000000..56a8026 --- /dev/null +++ b/agents/random_agent.py @@ -0,0 +1,135 @@ +import numpy as np + +from agents import Agent +from memory.api.models import MazeActionSpace, MazeObservation +from memory.config.memory_config import ( + MemoryConfig, + RedisIMConfig, + RedisSTMConfig, + SQLiteLTMConfig, +) +from memory.space import MemorySpace +from memory.utils.util import convert_numpy_to_python + + +class RandomAgent(Agent): + def __init__( + self, agent_id: str, action_space: int | MazeActionSpace = 4, **kwargs + ): + if isinstance(action_space, MazeActionSpace): + self.action_space = action_space.n + self.action_space_model = action_space + else: + self.action_space = action_space + self.action_space_model = MazeActionSpace(n=action_space) + self.agent_id = agent_id + self.demo_path = None + self.demo_step = 0 + for key, value in kwargs.items(): + 
setattr(self, key, value) + + def act(self, observation: MazeObservation, epsilon: float = 0.1) -> int: + if self.demo_path is not None and self.demo_step < len(self.demo_path): + action = self.demo_path[self.demo_step] + self.demo_step += 1 + return int(action) + return int(np.random.randint(self.action_space)) + + def set_demo_path(self, path: list[int]) -> None: + self.demo_path = path + self.demo_step = 0 + + +class MemoryRandomAgent(RandomAgent): + def __init__( + self, agent_id: str, action_space: int | MazeActionSpace = 4, **kwargs + ): + super().__init__(agent_id, action_space, **kwargs) + memory_config = MemoryConfig( + stm_config=RedisSTMConfig( + ttl=120, + memory_limit=500, + use_mock=True, + ), + im_config=RedisIMConfig( + ttl=240, + memory_limit=1000, + compression_level=0, + use_mock=True, + ), + ltm_config=SQLiteLTMConfig( + compression_level=0, + batch_size=20, + db_path="memory_demo.db", + ), + cleanup_interval=1000, + enable_memory_hooks=False, + use_embedding_engine=True, + text_model_name="all-MiniLM-L6-v2", + ) + self.memory = MemorySpace(agent_id, memory_config) + self.position_memory_cache = {} + self.step_number = 0 + + def act(self, observation: MazeObservation, epsilon: float = 0.1) -> int: + self.step_number += 1 + if self.demo_path is not None and self.demo_step < len(self.demo_path): + action = self.demo_path[self.demo_step] + self.demo_step += 1 + return int(action) + position_key = str(observation.position) + state_key = f"{observation.position}|{observation.target}|{observation.steps}" + # Try to retrieve similar states from memory + try: + query_state = { + "position": observation.position, + "target": observation.target, + "steps": observation.steps, + "manhattan_distance": abs( + observation.position[0] - observation.target[0] + ) + + abs(observation.position[1] - observation.target[1]), + } + similar_states = self.memory.retrieve_similar_states( + query_state=query_state, + k=10, + memory_type="state", + ) + if len(similar_states) == 0 and position_key in self.position_memory_cache: + similar_states = self.position_memory_cache[position_key] + for s in similar_states: + mem_position = None + if "position" in s.get("content", {}): + mem_position = str(s["content"]["position"]) + elif "next_state" in s.get("content", {}): + mem_position = str(s["content"]["next_state"]) + if mem_position: + if mem_position not in self.position_memory_cache: + self.position_memory_cache[mem_position] = [] + if s not in self.position_memory_cache[mem_position]: + self.position_memory_cache[mem_position].append(s) + if similar_states and np.random.random() > 0.2: + actions_from_memory = [] + for s in similar_states: + if "action" in s.get("content", {}): + reward = s["content"].get("reward", -1) + weight = 1 + if reward > -2: + weight = 3 + if reward > 0: + weight = 5 + for _ in range(weight): + actions_from_memory.append(s["content"]["action"]) + if actions_from_memory: + chosen_action = max( + set(actions_from_memory), key=actions_from_memory.count + ) + return int(chosen_action) + except Exception: + pass + # Otherwise, act randomly + return int(np.random.randint(self.action_space)) + + def set_demo_path(self, path: list[int]) -> None: + self.demo_path = path + self.demo_step = 0 diff --git a/main_demo.py b/main_demo.py index d4ec0ba..df1afe1 100644 --- a/main_demo.py +++ b/main_demo.py @@ -50,7 +50,7 @@ def run_experiment(episodes=100, memory_enabled=True, random_seed=None): # Create the optimal path for demonstration #! Why do I need this? 
- optimal_path = create_optimal_path_for_maze(maze_size) + # optimal_path = create_optimal_path_for_maze(maze_size) # Create agent based on memory flag agent_id = "agent_memory" if memory_enabled else "standard_agent" @@ -99,13 +99,13 @@ def run_experiment(episodes=100, memory_enabled=True, random_seed=None): agent = MemoryAgent(agent_id, memory_system, action_space=4) # Set the demonstration path for the first episode - agent.set_demo_path(optimal_path) + # agent.set_demo_path(optimal_path) print("Created memory agent with text embedding engine (no autoencoder)") else: agent = SimpleAgent(agent_id, action_space=4) # No memory, but still give the demo path for the first episode - agent.set_demo_path(optimal_path) + # agent.set_demo_path(optimal_path) # Track metrics rewards_per_episode = [] diff --git a/maze.py b/maze.py index 0a35acf..5c5701b 100644 --- a/maze.py +++ b/maze.py @@ -16,6 +16,7 @@ obs, reward, done = env.step(1) # Take action 'right' """ +from memory.api.models import MazeObservation, MazeActionSpace class MazeEnvironment: """ @@ -28,6 +29,7 @@ class MazeEnvironment: max_steps (int): Maximum steps per episode. position (tuple[int, int]): Current agent position. steps (int): Steps taken in current episode. + action_space (MazeActionSpace): The action space for the environment. """ def __init__( @@ -48,6 +50,7 @@ def __init__( self.obstacles = obstacles or [] self.target = (size - 2, size - 2) self.max_steps = max_steps + self.action_space = MazeActionSpace() self.reset() def reset(self) -> dict: @@ -61,19 +64,19 @@ def reset(self) -> dict: self.steps = 0 return self.get_observation() - def get_observation(self) -> dict: + def get_observation(self) -> MazeObservation: """ Get the current observation of the environment. Returns: - dict: Observation containing position, target, nearby obstacles, and steps. + MazeObservation: Observation containing position, target, nearby obstacles, and steps. """ - return { - "position": self.position, - "target": self.target, - "nearby_obstacles": self._get_nearby_obstacles(), - "steps": self.steps, - } + return MazeObservation( + position=self.position, + target=self.target, + nearby_obstacles=self._get_nearby_obstacles(), + steps=self.steps, + ) def _get_nearby_obstacles(self) -> list[tuple[int, int]]: """ @@ -135,3 +138,12 @@ def step(self, action: int) -> tuple[dict, float, bool]: done = False return self.get_observation(), reward, done + + def get_action_space(self) -> MazeActionSpace: + """ + Get the action space for the environment. + + Returns: + MazeActionSpace: The action space model for the maze. + """ + return self.action_space diff --git a/memory/api/models.py b/memory/api/models.py index dd21365..5637e2d 100644 --- a/memory/api/models.py +++ b/memory/api/models.py @@ -158,3 +158,29 @@ class ActionResult(BaseModel): action_type: str params: Dict[str, Any] = Field(default_factory=dict) reward: float = 0.0 + + +class MazeObservation(BaseModel): + """Observation model for MazeEnvironment. + + Attributes: + position: The agent's current position (row, col). + target: The target position (row, col). + nearby_obstacles: List of nearby obstacle coordinates (row, col). + steps: Number of steps taken in the current episode. + """ + position: tuple[int, int] + target: tuple[int, int] + nearby_obstacles: list[tuple[int, int]] + steps: int + + +class MazeActionSpace(BaseModel): + """Action space model for MazeEnvironment. + + Attributes: + n: Number of discrete actions. 
+        actions: List of action descriptions (e.g., ["up", "right", "down", "left"]).
+    """
+    n: int = 4
+    actions: list[str] = Field(default_factory=lambda: ["up", "right", "down", "left"])
diff --git a/memory/space.py b/memory/space.py
index f89ff81..8e0421d 100644
--- a/memory/space.py
+++ b/memory/space.py
@@ -523,6 +523,7 @@ def retrieve_similar_states(
         threshold: float = 0.6,
         context_weights: Dict[str, float] = None,
     ) -> List[Dict[str, Any]]:
+        #! Needs to be updated to use the SimilaritySearch strategy
         """Retrieve most similar past states to the provided query state.
 
         Args:
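# End-to-end sketch of the typed environment API added in this diff; assumes
# MazeEnvironment's existing constructor defaults (size, obstacles, max_steps) still apply.
from maze import MazeEnvironment
from memory.api.models import MazeActionSpace, MazeObservation

env = MazeEnvironment()                      # relies on the existing default arguments
space = env.get_action_space()
assert isinstance(space, MazeActionSpace) and space.n == 4

obs = env.get_observation()
assert isinstance(obs, MazeObservation)
next_obs, reward, done = env.step(1)         # 'right', per the module docstring example
assert isinstance(next_obs, MazeObservation)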