Add Q-learning state initialization and new README documentation #164
Merged
@@ -0,0 +1,92 @@

# Agents Module

This module provides a variety of agent classes for use in reinforcement learning and maze navigation environments. Agents can be used as-is or extended for custom behaviors. Many agents have both standard and memory-augmented variants that leverage episodic and semantic memory for improved performance.

## Agent Types

### 1. `Agent` (Abstract Base Class)
Defines the interface for all agents. To implement a custom agent, inherit from this class and implement the required methods.

**API:**
```python
class Agent(ABC):
    def __init__(self, agent_id: str, action_space, **kwargs): ...

    @abstractmethod
    def act(self, observation: MazeObservation, epsilon: float = 0.1) -> int: ...

    @abstractmethod
    def set_demo_path(self, path: list[int]) -> None: ...
```
### 2. `RandomAgent`
Selects actions randomly from the action space. Useful as a baseline.

### 3. `MemoryRandomAgent`
A random agent that also stores and retrieves state/action information from a memory system, biasing action selection toward previously successful actions.

### 4. `AlgoAgent`
A planning agent that uses search algorithms (BFS/DFS or custom) to plan a path to the target. Good for deterministic environments.
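For orientation, breadth-first path planning over a grid can be sketched as follows. This is a standalone illustration, not the `AlgoAgent` implementation; the cell encoding, grid bounds, and neighbor rules here are assumptions.

```python
from collections import deque


def bfs_path(start, target, obstacles, width, height):
    """Return a list of (row, col) cells from start to target, or [] if unreachable.

    Illustrative sketch only: the real AlgoAgent plans from its own maze model.
    """
    queue = deque([start])
    parents = {start: None}  # also serves as the visited set
    while queue:
        cell = queue.popleft()
        if cell == target:
            # Reconstruct the path by walking the parent links back to start.
            path = []
            while cell is not None:
                path.append(cell)
                cell = parents[cell]
            return path[::-1]
        r, c = cell
        for neighbor in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
            nr, nc = neighbor
            in_bounds = 0 <= nr < height and 0 <= nc < width
            if in_bounds and neighbor not in obstacles and neighbor not in parents:
                parents[neighbor] = cell
                queue.append(neighbor)
    return []
```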
### 5. `MemoryAlgoAgent`
A planning agent with memory augmentation. Retrieves similar states from memory to bias planning and action selection.

### 6. `QAgent`
Implements tabular Q-learning. Maintains a Q-table for state-action values and uses an epsilon-greedy policy.
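For orientation, the epsilon-greedy selection and Q-value update at the core of tabular Q-learning can be sketched as follows. This is a standalone illustration, not the `QAgent` implementation; the state-key encoding and the hyperparameter values are assumptions.

```python
from collections import defaultdict

import numpy as np

# Illustrative sketch: state_key is assumed to be any hashable encoding of an observation.
q_table = defaultdict(lambda: np.zeros(4))  # one Q-row per state, 4 actions assumed
alpha, gamma, epsilon = 0.1, 0.99, 0.1      # assumed hyperparameters


def select_action(state_key):
    # Explore with probability epsilon, otherwise exploit the best known action.
    if np.random.random() < epsilon:
        return int(np.random.randint(4))
    return int(np.argmax(q_table[state_key]))


def update_q_value(state_key, action, reward, next_state_key, done):
    # Q-learning target: r + gamma * max_a' Q(s', a') for non-terminal transitions.
    target = reward if done else reward + gamma * float(np.max(q_table[next_state_key]))
    q_table[state_key][action] += alpha * (target - q_table[state_key][action])
```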
### 7. `MemoryQAgent`
A Q-learning agent with memory augmentation. Stores and retrieves states, actions, and interactions from memory to bias exploration and exploitation.

### 8. `DeepQAgent`
Implements Deep Q-Learning using PyTorch. Uses a neural network to approximate Q-values and experience replay for training.
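The main ingredients of such an agent (a Q-network, a replay buffer, and regression onto the TD target) can be sketched as follows. This is an illustration only, not the `DeepQAgent` implementation; the network architecture, buffer capacity, and the assumption that states are encoded as 4 floats are placeholders.

```python
import random
from collections import deque

import torch
import torch.nn as nn

# Illustrative sketch only; sizes and state encoding are assumptions.
q_net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 4))
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-3)
replay_buffer = deque(maxlen=10_000)  # holds (state, action, reward, next_state, done)
gamma = 0.99


def replay_update(batch_size=32):
    """One experience-replay step: sample transitions and fit Q(s, a) to the TD target."""
    if len(replay_buffer) < batch_size:
        return
    batch = random.sample(replay_buffer, batch_size)
    states = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s, *_ in batch])
    actions = torch.tensor([a for _, a, *_ in batch])
    rewards = torch.tensor([r for _, _, r, *_ in batch], dtype=torch.float32)
    next_states = torch.stack([torch.as_tensor(ns, dtype=torch.float32) for *_, ns, _ in batch])
    dones = torch.tensor([d for *_, d in batch], dtype=torch.float32)

    # Q-values of the actions actually taken.
    q_values = q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        # TD target; terminal transitions keep only the immediate reward.
        targets = rewards + gamma * q_net(next_states).max(1).values * (1 - dones)

    loss = nn.functional.mse_loss(q_values, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```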
### 9. `MemoryDeepQAgent`
A deep Q-learning agent with memory augmentation. Stores and retrieves states and interactions from memory to bias action selection and learning.

---

## Usage

> **Note:** Only the abstract `Agent` is exposed in `agents/__init__.py`. To use concrete agents, import them directly from their respective files:

```python
from agents.random_agent import RandomAgent, MemoryRandomAgent
from agents.algo_agent import AlgoAgent, MemoryAlgoAgent
from agents.q_agent import QAgent, MemoryQAgent
from agents.deep_q_agent import DeepQAgent, MemoryDeepQAgent
```

## Example

```python
from agents.q_agent import QAgent
from memory.api.models import MazeObservation

agent = QAgent(agent_id="A1", action_space=4)
obs = MazeObservation(position=(0, 0), target=(3, 3), steps=0, nearby_obstacles=[])
action = agent.act(obs)
```

## Extending Agents
To create your own agent, inherit from `Agent` and implement the `act` and `set_demo_path` methods.
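For example, a minimal custom agent might look like the following sketch. It relies only on the constructor and abstract-method signatures shown above; the mapping of actions 0-3 to up/down/left/right is an assumption, not the module's actual encoding.

```python
from agents import Agent  # the abstract base class exposed in agents/__init__.py
from memory.api.models import MazeObservation


class GreedyManhattanAgent(Agent):
    """Hypothetical example agent that steps greedily toward the target.

    Assumes actions 0-3 mean up/down/left/right; adjust to the real action encoding.
    """

    def __init__(self, agent_id: str, action_space, **kwargs):
        super().__init__(agent_id, action_space, **kwargs)
        self.demo_path: list[int] = []

    def act(self, observation: MazeObservation, epsilon: float = 0.1) -> int:
        # Replay a demonstration path first, if one was provided.
        if self.demo_path:
            return self.demo_path.pop(0)
        dr = observation.target[0] - observation.position[0]
        dc = observation.target[1] - observation.position[1]
        if abs(dr) >= abs(dc):
            return 1 if dr > 0 else 0  # down / up (assumed encoding)
        return 3 if dc > 0 else 2      # right / left (assumed encoding)

    def set_demo_path(self, path: list[int]) -> None:
        self.demo_path = list(path)
```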
## Memory-Augmented Agents
Memory-augmented agents use a `MemorySpace` object to store and retrieve states, actions, and interactions. This enables:
- Retrieval of similar past states for biasing action selection
- Storing successful actions/interactions for future use
- Episodic and semantic memory integration
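For illustration, a memory-biased action choice might look like the sketch below. It assumes the `retrieve_similar_states` return format used in the tests (a list of dicts with a `content` payload holding `action` and `reward`) and an illustrative 0.2 exploration threshold; the actual agents' logic may differ.

```python
import numpy as np


def memory_biased_action(memory, observation, action_space: int) -> int:
    """Sketch: prefer the action from the best-rewarded similar memory, else act randomly."""
    similar = memory.retrieve_similar_states(observation)
    if similar and np.random.random() > 0.2:  # assumed exploration threshold
        best = max(similar, key=lambda m: m["content"].get("reward", 0))
        return best["content"]["action"]
    return int(np.random.randint(action_space))
```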
## Requirements
- `memory` module (for memory-augmented agents)
- `numpy`, `torch` (for DeepQAgent)

---

## File Overview
- `base.py`: Abstract base class
- `random_agent.py`: RandomAgent, MemoryRandomAgent
- `algo_agent.py`: AlgoAgent, MemoryAlgoAgent
- `q_agent.py`: QAgent, MemoryQAgent
- `deep_q_agent.py`: DeepQAgent, MemoryDeepQAgent

---

For more details, see the docstrings in each agent class.
@@ -0,0 +1,65 @@

```python
import pytest
import numpy as np
from unittest.mock import MagicMock

from agents.algo_agent import AlgoAgent, MemoryAlgoAgent
from memory.api.models import MazeObservation


@pytest.fixture
def sample_observation():
    return MazeObservation(
        position=(1, 1),
        target=(2, 2),
        nearby_obstacles=[(0, 1), (1, 0)],
        steps=5,
    )


def test_algo_agent_bfs_path(sample_observation):
    agent = AlgoAgent(agent_id="test", action_space=4, search_algo="bfs")
    action = agent.act(sample_observation)
    assert 0 <= action < 4


def test_algo_agent_dfs_path(sample_observation):
    agent = AlgoAgent(agent_id="test", action_space=4, search_algo="dfs")
    action = agent.act(sample_observation)
    assert 0 <= action < 4


def test_algo_agent_demo_path(sample_observation):
    agent = AlgoAgent(agent_id="test", action_space=4)
    agent.set_demo_path([1, 2])
    assert agent.act(sample_observation) == 1
    assert agent.act(sample_observation) == 2
    # After demo path, should revert to planning
    action = agent.act(sample_observation)
    assert 0 <= action < 4


def test_memory_algo_agent_act_returns_valid_action(sample_observation):
    agent = MemoryAlgoAgent(agent_id="test", action_space=4)
    agent.memory = MagicMock()
    agent.memory.retrieve_similar_states.return_value = []
    action = agent.act(sample_observation)
    assert 0 <= action < 4


def test_memory_algo_agent_demo_path(sample_observation):
    agent = MemoryAlgoAgent(agent_id="test", action_space=4)
    agent.set_demo_path([3, 0])
    assert agent.act(sample_observation) == 3
    assert agent.act(sample_observation) == 0


def test_memory_algo_agent_memory_action(sample_observation):
    agent = MemoryAlgoAgent(agent_id="test", action_space=4)
    agent.memory = MagicMock()
    agent.memory.retrieve_similar_states.return_value = [
        {"content": {"action": 1, "reward": 1}}
    ]
    np_random_backup = np.random.random
    np.random.random = lambda: 0.5
    action = agent.act(sample_observation)
    np.random.random = np_random_backup
    assert action == 1
```
@@ -0,0 +1,70 @@

```python
import pytest
import numpy as np
from unittest.mock import MagicMock
import torch

from agents.deep_q_agent import DeepQAgent, MemoryDeepQAgent
from memory.api.models import MazeObservation


@pytest.fixture
def sample_observation():
    return MazeObservation(
        position=(0, 0),
        target=(1, 1),
        nearby_obstacles=[(0, 1)],
        steps=1,
    )


@pytest.fixture
def next_observation():
    return MazeObservation(
        position=(0, 1),
        target=(1, 1),
        nearby_obstacles=[(1, 1)],
        steps=2,
    )


def test_deep_q_agent_epsilon_greedy_action(sample_observation):
    agent = DeepQAgent(agent_id="test", action_space=4)
    np_random_backup = np.random.random
    np.random.random = lambda: 0.05
    action = agent.act(sample_observation, epsilon=1.0)
    np.random.random = np_random_backup
    assert 0 <= action < 4


def test_deep_q_agent_demo_path(sample_observation):
    agent = DeepQAgent(agent_id="test", action_space=4)
    agent.set_demo_path([2, 1])
    assert agent.act(sample_observation) == 2
    assert agent.act(sample_observation) == 1


def test_deep_q_agent_experience_replay(sample_observation, next_observation):
    agent = DeepQAgent(agent_id="test", action_space=4, batch_size=1)
    agent.remember(sample_observation, 1, 1.0, next_observation, False)
    # Should not raise error
    agent.update()


def test_memory_deep_q_agent_act_returns_valid_action(sample_observation):
    agent = MemoryDeepQAgent(agent_id="test", action_space=4)
    agent.memory = MagicMock()
    agent.memory.retrieve_similar_states.return_value = []
    action = agent.act(sample_observation)
    assert 0 <= action < 4


def test_memory_deep_q_agent_demo_path(sample_observation):
    agent = MemoryDeepQAgent(agent_id="test", action_space=4)
    agent.set_demo_path([3, 0])
    assert agent.act(sample_observation) == 3
    assert agent.act(sample_observation) == 0


def test_memory_deep_q_agent_memory_action(sample_observation):
    agent = MemoryDeepQAgent(agent_id="test", action_space=4)
    agent.memory = MagicMock()
    agent.memory.retrieve_similar_states.return_value = [
        {"content": {"action": 2, "reward": 1}}
    ]
    np_random_backup = np.random.random
    np.random.random = lambda: 0.5
    action = agent.act(sample_observation)
    np.random.random = np_random_backup
    assert action == 2
```
@@ -0,0 +1,73 @@

```python
import pytest
import numpy as np
from unittest.mock import MagicMock

from agents.q_agent import QAgent, MemoryQAgent
from memory.api.models import MazeObservation


@pytest.fixture
def sample_observation():
    return MazeObservation(
        position=(0, 0),
        target=(1, 1),
        nearby_obstacles=[(0, 1)],
        steps=1,
    )


@pytest.fixture
def next_observation():
    return MazeObservation(
        position=(0, 1),
        target=(1, 1),
        nearby_obstacles=[(1, 1)],
        steps=2,
    )


def test_q_agent_epsilon_greedy_action(sample_observation):
    agent = QAgent(agent_id="test", action_space=4)
    # Force random action
    np_random_backup = np.random.random
    np.random.random = lambda: 0.05
    action = agent.act(sample_observation, epsilon=1.0)
    np.random.random = np_random_backup
    assert 0 <= action < 4


def test_q_agent_demo_path(sample_observation):
    agent = QAgent(agent_id="test", action_space=4)
    agent.set_demo_path([2, 1])
    assert agent.act(sample_observation) == 2
    assert agent.act(sample_observation) == 1


def test_q_agent_q_value_update(sample_observation, next_observation):
    agent = QAgent(agent_id="test", action_space=4)
    action = 1
    reward = 1.0
    done = False
    agent.update_q_value(sample_observation, action, reward, next_observation, done)
    state_key = agent._get_state_key(sample_observation)
    assert agent.q_table[state_key][action] != 0


def test_memory_q_agent_act_returns_valid_action(sample_observation):
    agent = MemoryQAgent(agent_id="test", action_space=4)
    agent.memory = MagicMock()
    agent.memory.retrieve_similar_states.return_value = []
    action = agent.act(sample_observation)
    assert 0 <= action < 4


def test_memory_q_agent_demo_path(sample_observation):
    agent = MemoryQAgent(agent_id="test", action_space=4)
    agent.set_demo_path([3, 0])
    assert agent.act(sample_observation) == 3
    assert agent.act(sample_observation) == 0


def test_memory_q_agent_memory_action(sample_observation):
    agent = MemoryQAgent(agent_id="test", action_space=4)
    agent.memory = MagicMock()
    agent.memory.retrieve_similar_states.return_value = [
        {"content": {"action": 2, "reward": 1}}
    ]
    np_random_backup = np.random.random
    np.random.random = lambda: 0.5
    action = agent.act(sample_observation)
    np.random.random = np_random_backup
    assert action == 2
```
@@ -0,0 +1,63 @@

```python
import pytest
import numpy as np
from unittest.mock import MagicMock

from agents.random_agent import RandomAgent, MemoryRandomAgent
from memory.api.models import MazeObservation, MazeActionSpace


@pytest.fixture
def sample_observation():
    return MazeObservation(
        position=(1, 1),
        target=(2, 2),
        nearby_obstacles=[(0, 1), (1, 0)],
        steps=5,
    )


def test_random_agent_act_returns_valid_action(sample_observation):
    agent = RandomAgent(agent_id="test", action_space=4)
    action = agent.act(sample_observation)
    assert 0 <= action < 4


def test_random_agent_demo_path(sample_observation):
    agent = RandomAgent(agent_id="test", action_space=4)
    agent.set_demo_path([2, 3, 1])
    assert agent.act(sample_observation) == 2
    assert agent.act(sample_observation) == 3
    assert agent.act(sample_observation) == 1
    # After demo path, should revert to random
    action = agent.act(sample_observation)
    assert 0 <= action < 4


def test_memory_random_agent_act_returns_valid_action(sample_observation, monkeypatch):
    agent = MemoryRandomAgent(agent_id="test", action_space=4)
    # Patch memory.retrieve_similar_states to return empty
    agent.memory = MagicMock()
    agent.memory.retrieve_similar_states.return_value = []
    action = agent.act(sample_observation)
    assert 0 <= action < 4


def test_memory_random_agent_demo_path(sample_observation):
    agent = MemoryRandomAgent(agent_id="test", action_space=4)
    agent.set_demo_path([1, 0])
    assert agent.act(sample_observation) == 1
    assert agent.act(sample_observation) == 0


def test_memory_random_agent_memory_action(sample_observation):
    agent = MemoryRandomAgent(agent_id="test", action_space=4)
    # Patch memory.retrieve_similar_states to return a memory with action 2
    agent.memory = MagicMock()
    agent.memory.retrieve_similar_states.return_value = [
        {"content": {"action": 2, "reward": 1}}
    ]
    # Patch np.random.random to always return 0.5 (> 0.2)
    np_random_backup = np.random.random
    np.random.random = lambda: 0.5
    action = agent.act(sample_observation)
    np.random.random = np_random_backup
    assert action == 2
```
[nitpick] Consider refactoring the Q-table initialization into a helper method to reduce duplication, since similar initialization logic is applied in both select_action and update_q_value methods.
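A possible shape for that helper is sketched below. The helper name `_ensure_state` is hypothetical and the bodies of `select_action` and `update_q_value` are elided; only the shared initialization path is shown, and the real implementation may differ.

```python
import numpy as np


class QAgent:
    # Sketch of the suggested refactor (names assumed from the review comment);
    # only the Q-table initialization logic is shown here.

    def _ensure_state(self, state_key):
        """Lazily create a zeroed Q-row for an unseen state, in one place."""
        if state_key not in self.q_table:
            self.q_table[state_key] = np.zeros(self.action_space)
        return self.q_table[state_key]

    def select_action(self, observation, epsilon=0.1):
        q_values = self._ensure_state(self._get_state_key(observation))
        ...  # epsilon-greedy selection over q_values, unchanged

    def update_q_value(self, observation, action, reward, next_observation, done):
        q_values = self._ensure_state(self._get_state_key(observation))
        next_q_values = self._ensure_state(self._get_state_key(next_observation))
        ...  # Q-learning update using q_values / next_q_values, unchanged
```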