Dev #163 (Merged)

agents.py: 136 changes (73 additions, 63 deletions)
@@ -12,13 +12,9 @@

import numpy as np

from memory import (
MemoryConfig,
MemorySpace,
RedisIMConfig,
RedisSTMConfig,
SQLiteLTMConfig,
)
from memory import (MemoryConfig, MemorySpace, RedisIMConfig, RedisSTMConfig,
SQLiteLTMConfig)
from memory.api.models import MazeActionSpace, MazeObservation
from memory.utils.util import convert_numpy_to_python


@@ -27,7 +23,7 @@ class SimpleAgent:
def __init__(
self,
agent_id: str,
action_space: int = 4,
action_space: int | MazeActionSpace = 4,
learning_rate: float = 0.1,
discount_factor: float = 0.9,
**kwargs,
@@ -37,41 +33,48 @@ def __init__(

Args:
agent_id (str): Unique identifier for the agent.
action_space (int): Number of possible actions.
action_space (int or MazeActionSpace): Number of possible actions or MazeActionSpace object.
learning_rate (float): Q-learning learning rate.
discount_factor (float): Q-learning discount factor.
**kwargs: Additional arguments (unused).
**kwargs: Additional arguments.
"""
self.agent_id = agent_id
self.action_space = action_space
if isinstance(action_space, MazeActionSpace):
self.action_space = action_space.n
self.action_space_model = action_space
else:
self.action_space = action_space
self.action_space_model = MazeActionSpace(n=action_space)

self.learning_rate = learning_rate
self.discount_factor = discount_factor
self.q_table = {} # State-action values
self.q_table = {}  # State-action values. TODO: support fuzzier (similarity-based) state lookup rather than exact key match.
self.current_observation = None
self.demo_path = None # For scripted demo actions
self.demo_step = 0
self.step_number = 0

def _get_state_key(self, observation: dict) -> str:
for key, value in kwargs.items():
setattr(self, key, value)

def _get_state_key(self, observation: MazeObservation) -> str:
"""
Generate a unique key for a given observation/state.

Args:
observation (dict): The environment observation.
observation (MazeObservation): The environment observation.

Returns:
str: A string key representing the state.
"""
return (
f"{observation['position']}|{observation['target']}|{observation['steps']}"
)
return f"{observation.position}|{observation.target}|{observation.steps}"

def select_action(self, observation: dict, epsilon: float = 0.1) -> int:
def select_action(self, observation: MazeObservation, epsilon: float = 0.1) -> int:
"""
Select an action using an epsilon-greedy policy or a demonstration path.

Args:
observation (dict): The current environment observation.
observation (MazeObservation): The current environment observation.
epsilon (float): Probability of choosing a random action (exploration).

Returns:
@@ -98,20 +101,20 @@ def select_action(self, observation: dict, epsilon: float = 0.1) -> int:

def update_q_value(
self,
observation: dict,
observation: MazeObservation,
action: int,
reward: float,
next_observation: dict,
next_observation: MazeObservation,
done: bool,
) -> None:
"""
Update the Q-value for a state-action pair using the Q-learning rule.

Args:
observation (dict): The current state observation.
observation (MazeObservation): The current state observation.
action (int): The action taken.
reward (float): The reward received.
next_observation (dict): The next state observation.
next_observation (MazeObservation): The next state observation.
done (bool): Whether the episode has ended.
"""
state_key = self._get_state_key(observation)
@@ -134,22 +137,25 @@ def update_q_value(
)
self.q_table[state_key][action] = new_q

def act(self, observation: dict, epsilon: float = 0.1) -> int:
def act(self, observation: MazeObservation, epsilon: float = 0.1) -> int:
"""
Choose and return an action for the given observation.

Args:
observation (dict): The current environment observation.
observation (MazeObservation): The current environment observation.
epsilon (float): Probability of choosing a random action (exploration).

Returns:
int: The selected action index.
"""
self.step_number += 1
# Convert NumPy types to Python types
self.current_observation = convert_numpy_to_python(observation)
obs = convert_numpy_to_python(observation)
if not isinstance(obs, MazeObservation):
obs = MazeObservation(**obs)
self.current_observation = obs
action = self.select_action(self.current_observation, epsilon)
return int(action) # Return as integer instead of ActionResult
return int(action)

def set_demo_path(self, path: list[int]) -> None:
"""
@@ -167,7 +173,7 @@ class MemoryAgent(SimpleAgent):
def __init__(
self,
agent_id: str,
action_space: int = 4,
action_space: int | MazeActionSpace = 4,
learning_rate: float = 0.1,
discount_factor: float = 0.9,
**kwargs,
@@ -177,7 +183,7 @@ def __init__(

Args:
agent_id (str): Unique identifier for the agent.
action_space (int): Number of possible actions.
action_space (int or MazeActionSpace): Number of possible actions or MazeActionSpace object.
learning_rate (float): Q-learning learning rate.
discount_factor (float): Q-learning discount factor.
**kwargs: Additional arguments (unused).
@@ -213,27 +219,27 @@ def __init__(
text_model_name="all-MiniLM-L6-v2", # Use a default text embedding model
)
# Store the memory system and get the memory space for this agent
self.memory_space = MemorySpace(agent_id, memory_config)
self.memory = MemorySpace(agent_id, memory_config)

# Keep track of visited states to avoid redundant storage
self.visited_states = set()
# Add memory cache for direct position lookups
self.position_memory_cache = {} # Mapping from positions to memories

def select_action(self, observation: dict, epsilon: float = 0.1) -> int:
def select_action(self, observation: MazeObservation, epsilon: float = 0.1) -> int:
"""
Select an action using memory-augmented Q-learning and experience recall.

Args:
observation (dict): The current environment observation.
observation (MazeObservation): The current environment observation.
epsilon (float): Probability of choosing a random action (exploration).

Returns:
int: The selected action index.
"""
self.current_observation = observation
state_key = self._get_state_key(observation)
position_key = str(observation["position"]) # Use position as direct lookup key
position_key = str(observation.position) # Use position as direct lookup key

# Initialize state if not seen before
if state_key not in self.q_table:
@@ -251,18 +257,18 @@ def select_action(self, observation: dict, epsilon: float = 0.1) -> int:
if state_key not in self.visited_states:
# Enhanced state representation
enhanced_state = {
"position": observation["position"],
"target": observation["target"],
"steps": observation["steps"],
"nearby_obstacles": observation["nearby_obstacles"],
"position": observation.position,
"target": observation.target,
"steps": observation.steps,
"nearby_obstacles": observation.nearby_obstacles,
"manhattan_distance": abs(
observation["position"][0] - observation["target"][0]
observation.position[0] - observation.target[0]
)
+ abs(observation["position"][1] - observation["target"][1]),
+ abs(observation.position[1] - observation.target[1]),
"state_key": state_key,
"position_key": position_key, # Add position key for direct lookup
}
self.memory_space.store_state(
self.memory.store_state(
state_data=convert_numpy_to_python(enhanced_state),
step_number=self.step_number,
priority=0.7, # Medium priority for state
@@ -271,17 +277,18 @@ def select_action(self, observation: dict, epsilon: float = 0.1) -> int:

# Create a query with the enhanced state features
query_state = {
"position": observation["position"],
"target": observation["target"],
"steps": observation["steps"],
"position": observation.position,
"target": observation.target,
"steps": observation.steps,
"manhattan_distance": abs(
observation["position"][0] - observation["target"][0]
observation.position[0] - observation.target[0]
)
+ abs(observation["position"][1] - observation["target"][1]),
+ abs(observation.position[1] - observation.target[1]),
}

# Use search strategy directly
similar_states = self.memory_space.retrieve_similar_states(
# TODO: update this lookup to use the SimilaritySearch strategy
similar_states = self.memory.retrieve_similar_states(
query_state=query_state,
k=10, # Increase from 5 to 10 to find more candidates
memory_type="state",
@@ -360,21 +367,24 @@ def act(self, observation: dict, epsilon: float = 0.1) -> int:
"""
self.step_number += 1
# Convert NumPy types to Python types
self.current_observation = convert_numpy_to_python(observation)
obs = convert_numpy_to_python(observation)
if not isinstance(obs, MazeObservation):
obs = MazeObservation(**obs)
self.current_observation = obs
action = self.select_action(self.current_observation, epsilon)

# Store the action using memory space
try:
# Include more context in the action data
position_key = str(observation["position"])
position_key = str(observation.position)
action_data = {
"action": int(action),
"position": self.current_observation["position"],
"position": self.current_observation.position,
"state_key": self._get_state_key(self.current_observation),
"steps": self.current_observation["steps"],
"steps": self.current_observation.steps,
"position_key": position_key,
}
self.memory_space.store_action(
self.memory.store_action(
action_data=action_data,
step_number=self.step_number,
priority=0.6, # Medium priority
@@ -397,20 +407,20 @@ def act(self, observation: dict, epsilon: float = 0.1) -> int:

def update_q_value(
self,
observation: dict,
observation: MazeObservation,
action: int,
reward: float,
next_observation: dict,
next_observation: MazeObservation,
done: bool,
) -> None:
"""
Update the Q-value and store the reward and outcome in memory.

Args:
observation (dict): The current state observation.
observation (MazeObservation): The current state observation.
action (int): The action taken.
reward (float): The reward received.
next_observation (dict): The next state observation.
next_observation (MazeObservation): The next state observation.
done (bool): Whether the episode has ended.
"""
# First, call the parent method to update Q-values
@@ -419,21 +429,21 @@ def update_q_value(
# Then store the reward and outcome using memory space
try:
# Enhance interaction data with more context
position_key = str(observation["position"])
next_position_key = str(next_observation["position"])
position_key = str(observation.position)
next_position_key = str(next_observation.position)

interaction_data = {
"action": int(action),
"reward": float(reward),
"next_state": convert_numpy_to_python(next_observation["position"]),
"next_state": convert_numpy_to_python(next_observation.position),
"done": done,
"state_key": self._get_state_key(observation),
"next_state_key": self._get_state_key(next_observation),
"steps": observation["steps"],
"steps": observation.steps,
"manhattan_distance": abs(
observation["position"][0] - observation["target"][0]
observation.position[0] - observation.target[0]
)
+ abs(observation["position"][1] - observation["target"][1]),
+ abs(observation.position[1] - observation.target[1]),
"position_key": position_key,
"next_position_key": next_position_key,
}
@@ -443,7 +453,7 @@ def update_q_value(
if done and reward > 0: # Successful completion
priority = 1.0 # Maximum priority

self.memory_space.store_interaction(
self.memory.store_interaction(
interaction_data=interaction_data,
step_number=self.step_number,
priority=priority,
@@ -475,6 +485,6 @@ def update_q_value(

# Random agent that chooses actions randomly
class RandomAgent(SimpleAgent):
def select_action(self, observation, epsilon=0.1):
def select_action(self, observation: MazeObservation, epsilon: float = 0.1) -> int:
self.current_observation = observation
return np.random.randint(self.action_space)
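
As context for the action_space and observation typing changes above, here is a minimal usage sketch of the updated SimpleAgent, not part of the PR: the constructor accepts either an int or a MazeActionSpace, and act/update_q_value consume MazeObservation objects. The import path, the MazeObservation field values, and the toy transition are assumptions made only for illustration.

from agents import SimpleAgent
from memory.api.models import MazeActionSpace, MazeObservation

# The constructor now accepts either form; both end up with
# agent.action_space == 4 and agent.action_space_model set to MazeActionSpace(n=4).
agent = SimpleAgent(agent_id="demo-agent", action_space=MazeActionSpace(n=4))

# Field names (position, target, steps, nearby_obstacles) match those read in
# the diff; the concrete values are invented for the sketch.
obs = MazeObservation(position=(0, 0), target=(3, 3), steps=0, nearby_obstacles=[])
action = agent.act(obs, epsilon=0.2)  # epsilon-greedy over the agent's Q-table

# Hypothetical transition a maze environment might return.
next_obs = MazeObservation(position=(0, 1), target=(3, 3), steps=1, nearby_obstacles=[])
agent.update_q_value(obs, action, reward=-0.01, next_observation=next_obs, done=False)
# Internally this applies the standard Q-learning rule:
#   Q(s, a) <- Q(s, a) + lr * (reward + gamma * max_a' Q(s', a') - Q(s, a))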
agents/__init__.py: 1 change (1 addition, 0 deletions)
@@ -0,0 +1 @@
from .base import Agent
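
Separately from the diff itself, the MemoryAgent changes rename self.memory_space to self.memory while keeping the same MemorySpace calls. The hedged sketch below restates those calls outside the agent, using only the signatures visible in the diff (MemorySpace(agent_id, config), store_interaction, retrieve_similar_states); the MemoryConfig defaults, the demo keys, and the values are assumptions for illustration, not a definitive use of the memory library.

from memory import MemoryConfig, MemorySpace
from memory.utils.util import convert_numpy_to_python

# Only text_model_name appears in the diff; the remaining MemoryConfig fields
# are assumed to have workable defaults here.
memory = MemorySpace("demo-agent", MemoryConfig(text_model_name="all-MiniLM-L6-v2"))

# Successful terminal transitions get maximum priority, mirroring MemoryAgent.
interaction = {"action": 2, "reward": 1.0, "next_state": (3, 3), "done": True, "position_key": "(3, 2)"}
memory.store_interaction(
    interaction_data=convert_numpy_to_python(interaction),
    step_number=42,
    priority=1.0,
)

# Similarity lookup used by MemoryAgent.select_action (slated to move to the
# SimilaritySearch strategy per the inline TODO).
similar_states = memory.retrieve_similar_states(
    query_state={"position": (3, 2), "target": (3, 3), "steps": 42, "manhattan_distance": 1},
    k=10,
    memory_type="state",
)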