Skip to content

Commit bb5fad6

Browse files
author
Luca Carminati
committed
Fix tests
1 parent 73eed6c commit bb5fad6

File tree

2 files changed

+138
-16
lines changed

2 files changed

+138
-16
lines changed

test/test_collector.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,6 +1717,139 @@ def env_fn():
17171717
total_frames=frames_per_batch * 100,
17181718
)
17191719

1720+
class FixedIDEnv(EnvBase):
    """Mock environment whose sole observation is a fixed integer ID.

    Built to exercise MultiSyncDataCollector ordering: every instance is
    constructed with a unique ``env_id`` and emits that value as its
    observation on every reset and step, so a collected batch reveals
    which worker produced which frames.
    """

    def __init__(self, env_id: int, max_steps: int = 10, **kwargs):
        """
        Args:
            env_id: value emitted as the observation (as a float32 tensor).
            max_steps: number of steps after which the episode terminates.
        """
        super().__init__(device="cpu", batch_size=torch.Size([]))
        self.env_id = env_id
        self.max_steps = max_steps
        self._step_count = 0

        # Every leaf spec is a single-element tensor; reuse the kwargs.
        scalar_f32 = dict(shape=(1,), dtype=torch.float32)
        scalar_bool = dict(shape=(1,), dtype=torch.bool)
        self.observation_spec = Composite(observation=Unbounded(**scalar_f32))
        self.action_spec = Composite(action=Unbounded(**scalar_f32))
        self.reward_spec = Composite(reward=Unbounded(**scalar_f32))
        self.done_spec = Composite(
            done=Unbounded(**scalar_bool),
            terminated=Unbounded(**scalar_bool),
            truncated=Unbounded(**scalar_bool),
        )

    def _reset(self, tensordict: TensorDict | None = None, **kwargs) -> TensorDict:
        """Start a new episode; the observation is this env's ID."""
        # Jitter up to 10ms so workers finish their resets at different
        # times, stressing the collector's ordering guarantees.
        time.sleep(torch.rand(1).item() * 0.01)

        self._step_count = 0
        return TensorDict(
            {
                "observation": torch.full(
                    (1,), float(self.env_id), dtype=torch.float32
                ),
                "done": torch.zeros(1, dtype=torch.bool),
                "terminated": torch.zeros(1, dtype=torch.bool),
                "truncated": torch.zeros(1, dtype=torch.bool),
            },
            batch_size=self.batch_size,
        )

    def _step(self, tensordict: TensorDict) -> TensorDict:
        """Advance one step; emit the env ID, reward 1.0, done after max_steps."""
        self._step_count += 1
        episode_over = self._step_count >= self.max_steps

        return TensorDict(
            {
                "observation": torch.full(
                    (1,), float(self.env_id), dtype=torch.float32
                ),
                "reward": torch.ones(1, dtype=torch.float32),
                "done": torch.tensor([episode_over], dtype=torch.bool),
                "terminated": torch.tensor([episode_over], dtype=torch.bool),
                "truncated": torch.zeros(1, dtype=torch.bool),
            },
            batch_size=self.batch_size,
        )

    def _set_seed(self, seed: int | None) -> int | None:
        """Seed torch's global RNG (drives the reset-time jitter)."""
        if seed is None:
            return None
        torch.manual_seed(seed)
        return seed
1798+
1799+
@pytest.mark.parametrize("num_envs", [8])
def test_multi_sync_data_collector_ordering(self, num_envs: int):
    """
    Test that MultiSyncDataCollector returns data in the correct order.

    Each of ``num_envs`` environments emits its own env_id as the sole
    observation, so row ``i`` of the stacked batch must contain only the
    value ``i``.
    """
    frames_per_batch = num_envs * 5  # Collect 5 steps per environment

    # Create environment factories using partial - one for each env_id
    # This pattern mirrors CrossPlayEvaluator._rollout usage
    env_factories = [
        functools.partial(self.FixedIDEnv, env_id=i, max_steps=10)
        for i in range(num_envs)
    ]

    policy = ParametricPolicy()

    # Initialize MultiSyncDataCollector
    collector = MultiSyncDataCollector(
        create_env_fn=env_factories,
        policy=policy,
        frames_per_batch=frames_per_batch,
        total_frames=frames_per_batch,
        device="cpu",
    )

    # Fix: run collection inside try/finally so the collector's worker
    # processes are shut down even when an assertion below fails;
    # otherwise a failing test leaks workers for the rest of the session.
    try:
        # Collect one batch
        for batch in collector:
            # Verify that each environment's observations match its env_id
            # batch has shape [num_envs, frames_per_env]
            for env_idx in range(num_envs):
                env_data = batch[env_idx]
                observations = env_data["observation"]

                # All observations from this environment should equal its env_id
                expected_id = float(env_idx)
                actual_ids = observations.flatten().unique()

                assert len(actual_ids) == 1, (
                    f"Env {env_idx} should only produce observations with value {expected_id}, "
                    f"but got {actual_ids.tolist()}"
                )
                assert (
                    actual_ids[0].item() == expected_id
                ), f"Environment {env_idx} should produce observation {expected_id}, but got {actual_ids[0].item()}"

            # Only process the first batch
            break
    finally:
        collector.shutdown()
1852+
17201853

17211854
class TestCollectorDevices:
17221855
class DeviceLessEnv(EnvBase):

torchrl/collectors/collectors.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3760,7 +3760,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
37603760
cat_results = self.cat_results
37613761
if cat_results is None:
37623762
cat_results = "stack"
3763-
37643763
self.buffers = [None for _ in range(self.num_workers)]
37653764
dones = [False for _ in range(self.num_workers)]
37663765
workers_frames = [0 for _ in range(self.num_workers)]
@@ -3781,7 +3780,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
37813780
msg = "continue_random"
37823781
else:
37833782
msg = "continue"
3784-
# Debug: sending 'continue'
37853783
self.pipes[idx].send((None, msg))
37863784

37873785
self._iter += 1
@@ -3845,15 +3843,13 @@ def iterator(self) -> Iterator[TensorDictBase]:
38453843
# mask buffers if cat, and create a mask if stack
38463844
if cat_results != "stack":
38473845
buffers = [None] * self.num_workers
3848-
for worker_idx, buffer in enumerate(
3849-
filter(None.__ne__, self.buffers)
3850-
):
3846+
for idx, buffer in enumerate(filter(None.__ne__, self.buffers)):
38513847
valid = buffer.get(("collector", "traj_ids")) != -1
38523848
if valid.ndim > 2:
38533849
valid = valid.flatten(0, -2)
38543850
if valid.ndim == 2:
38553851
valid = valid.any(0)
3856-
buffers[worker_idx] = buffer[..., valid]
3852+
buffers[idx] = buffer[..., valid]
38573853
else:
38583854
for buffer in filter(None.__ne__, self.buffers):
38593855
with buffer.unlock_():
@@ -3865,11 +3861,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
38653861
else:
38663862
buffers = self.buffers
38673863

3868-
# Skip frame counting if this worker didn't send data this iteration
3869-
# (happens when reusing buffers or on first iteration with some workers)
3870-
if idx not in buffers:
3871-
continue
3872-
38733864
workers_frames[idx] = workers_frames[idx] + buffers[idx].numel()
38743865

38753866
if workers_frames[idx] >= self.total_frames:
@@ -3878,17 +3869,15 @@ def iterator(self) -> Iterator[TensorDictBase]:
38783869
if self.replay_buffer is not None:
38793870
yield
38803871
self._frames += sum(
3881-
[
3882-
self.frames_per_batch_worker(worker_idx)
3883-
for worker_idx in range(self.num_workers)
3884-
]
3872+
self.frames_per_batch_worker(worker_idx)
3873+
for worker_idx in range(self.num_workers)
38853874
)
38863875
continue
38873876

38883877
# we have to correct the traj_ids to make sure that they don't overlap
38893878
# We can count the number of frames collected for free in this loop
38903879
n_collected = 0
3891-
for idx, buffer in enumerate(filter(None.__ne__, buffers)):
3880+
for idx in range(self.num_workers):
38923881
buffer = buffers[idx]
38933882
traj_ids = buffer.get(("collector", "traj_ids"))
38943883
if preempt:

0 commit comments

Comments
 (0)