@@ -1717,6 +1717,139 @@ def env_fn():
17171717 total_frames = frames_per_batch * 100 ,
17181718 )
17191719
class FixedIDEnv(EnvBase):
    """Mock environment whose only observation is a constant identifier.

    Built to check the ordering of the data returned by
    MultiSyncDataCollector: every instance is given its own ``env_id`` and
    emits that value as the observation on reset and on every step.
    """

    def __init__(self, env_id: int, max_steps: int = 10, **kwargs):
        """
        Args:
            env_id: Identifier emitted (as a float32 tensor) as the observation.
            max_steps: Step count at which the episode is flagged as done.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        super().__init__(device="cpu", batch_size=torch.Size([]))
        self.env_id = env_id
        self.max_steps = max_steps
        self._step_count = 0

        # Every entry handled by this env is a one-element tensor.
        unit_shape = (1,)

        self.observation_spec = Composite(
            observation=Unbounded(shape=unit_shape, dtype=torch.float32)
        )
        self.action_spec = Composite(
            action=Unbounded(shape=unit_shape, dtype=torch.float32)
        )
        self.reward_spec = Composite(
            reward=Unbounded(shape=unit_shape, dtype=torch.float32)
        )
        self.done_spec = Composite(
            done=Unbounded(shape=unit_shape, dtype=torch.bool),
            terminated=Unbounded(shape=unit_shape, dtype=torch.bool),
            truncated=Unbounded(shape=unit_shape, dtype=torch.bool),
        )

    def _reset(self, tensordict: TensorDict | None = None, **kwargs) -> TensorDict:
        """Restart the step counter and emit the initial observation."""
        # Sleep for a random duration below 10ms so that the different
        # instances finish their resets at different times.
        jitter_seconds = torch.rand(1).item() * 0.01
        time.sleep(jitter_seconds)

        self._step_count = 0
        initial = {
            "observation": torch.full(
                (1,), float(self.env_id), dtype=torch.float32
            ),
            "done": torch.zeros(1, dtype=torch.bool),
            "terminated": torch.zeros(1, dtype=torch.bool),
            "truncated": torch.zeros(1, dtype=torch.bool),
        }
        return TensorDict(initial, batch_size=self.batch_size)

    def _step(self, tensordict: TensorDict) -> TensorDict:
        """Advance the step counter and emit ``env_id`` as the observation."""
        self._step_count += 1
        # The episode ends (done and terminated) once max_steps is reached;
        # truncated is never raised.
        episode_over = self._step_count >= self.max_steps

        transition = {
            "observation": torch.full(
                (1,), float(self.env_id), dtype=torch.float32
            ),
            "reward": torch.ones(1, dtype=torch.float32),
            "done": torch.full((1,), episode_over, dtype=torch.bool),
            "terminated": torch.full((1,), episode_over, dtype=torch.bool),
            "truncated": torch.zeros(1, dtype=torch.bool),
        }
        return TensorDict(transition, batch_size=self.batch_size)

    def _set_seed(self, seed: int | None) -> int | None:
        """Seed torch's global RNG when a seed is given; return the seed."""
        if seed is None:
            return seed
        torch.manual_seed(seed)
        return seed
1798+
@pytest.mark.parametrize("num_envs", [8])
def test_multi_sync_data_collector_ordering(self, num_envs: int):
    """Check that MultiSyncDataCollector returns data in worker order.

    ``num_envs`` environments are created, each emitting its own ``env_id``
    as the observation. After collecting one batch, the slice of the batch
    at index ``i`` must contain only the value ``i``.

    Args:
        num_envs: Number of environments (one per env factory) to collect from.
    """
    frames_per_batch = num_envs * 5  # Collect 5 steps per environment

    # One environment factory per env_id, built with functools.partial.
    # This pattern mirrors CrossPlayEvaluator._rollout usage.
    env_factories = [
        functools.partial(self.FixedIDEnv, env_id=i, max_steps=10)
        for i in range(num_envs)
    ]

    # A single policy instance is handed to the collector.
    policy = ParametricPolicy()

    collector = MultiSyncDataCollector(
        create_env_fn=env_factories,
        policy=policy,
        frames_per_batch=frames_per_batch,
        total_frames=frames_per_batch,
        device="cpu",
    )

    # The try/finally guarantees the collector is shut down even when an
    # assertion fails, so a failing test does not leave workers behind.
    try:
        num_batches = 0
        for batch in collector:
            num_batches += 1
            # batch is expected to have shape [num_envs, frames_per_env],
            # so indexing the first dimension selects one environment.
            for env_idx in range(num_envs):
                env_data = batch[env_idx]
                observations = env_data["observation"]

                # All observations from this environment should equal its env_id
                expected_id = float(env_idx)
                actual_ids = observations.flatten().unique()

                assert len(actual_ids) == 1, (
                    f"Env {env_idx} should only produce observations with value {expected_id}, "
                    f"but got {actual_ids.tolist()}"
                )
                assert (
                    actual_ids[0].item() == expected_id
                ), f"Environment {env_idx} should produce observation {expected_id}, but got {actual_ids[0].item()}"

            # Only process the first batch
            break

        # Without this check the test would pass vacuously if the
        # collector yielded nothing.
        assert num_batches == 1, "Collector did not yield any batch"
    finally:
        collector.shutdown()
1852+
17201853
17211854class TestCollectorDevices :
17221855 class DeviceLessEnv (EnvBase ):
0 commit comments