@@ -3761,7 +3761,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
37613761 if cat_results is None :
37623762 cat_results = "stack"
37633763
3764- self .buffers = {}
3764+ self .buffers = [ None for _ in range ( self . num_workers )]
37653765 dones = [False for _ in range (self .num_workers )]
37663766 workers_frames = [0 for _ in range (self .num_workers )]
37673767 same_device = None
@@ -3844,16 +3844,16 @@ def iterator(self) -> Iterator[TensorDictBase]:
38443844 if preempt :
38453845 # mask buffers if cat, and create a mask if stack
38463846 if cat_results != "stack" :
3847- buffers = {}
3848- for worker_idx , buffer in self .buffers . items ( ):
3847+ buffers = [ None ] * self . num_workers
3848+ for worker_idx , buffer in enumerate ( filter ( None . __ne__ , self .buffers ) ):
38493849 valid = buffer .get (("collector" , "traj_ids" )) != - 1
38503850 if valid .ndim > 2 :
38513851 valid = valid .flatten (0 , - 2 )
38523852 if valid .ndim == 2 :
38533853 valid = valid .any (0 )
38543854 buffers [worker_idx ] = buffer [..., valid ]
38553855 else :
3856- for buffer in self .buffers . values ( ):
3856+ for buffer in filter ( None . __ne__ , self .buffers ):
38573857 with buffer .unlock_ ():
38583858 buffer .set (
38593859 ("collector" , "mask" ),
@@ -3886,7 +3886,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
38863886 # we have to correct the traj_ids to make sure that they don't overlap
38873887 # We can count the number of frames collected for free in this loop
38883888 n_collected = 0
3889- for idx in buffers . keys ( ):
3889+ for idx , buffer in enumerate ( filter ( None . __ne__ , buffers ) ):
38903890 buffer = buffers [idx ]
38913891 traj_ids = buffer .get (("collector" , "traj_ids" ))
38923892 if preempt :
@@ -3901,7 +3901,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
39013901 if same_device is None :
39023902 prev_device = None
39033903 same_device = True
3904- for item in self .buffers . values ( ):
3904+ for item in filter ( None . __ne__ , self .buffers ):
39053905 if prev_device is None :
39063906 prev_device = item .device
39073907 else :
@@ -3912,33 +3912,21 @@ def iterator(self) -> Iterator[TensorDictBase]:
39123912 torch .stack if self ._use_buffers else TensorDict .maybe_dense_stack
39133913 )
39143914 if same_device :
3915- self .out_buffer = stack (list ( buffers . values ()) , 0 )
3915+ self .out_buffer = stack ([ item for item in buffers if item is not None ] , 0 )
39163916 else :
3917- self .out_buffer = stack (
3918- [item .cpu () for item in buffers .values ()], 0
3919- )
3917+ self .out_buffer = stack ([item .cpu () for item in buffers if item is not None ], 0 )
39203918 else :
39213919 if self ._use_buffers is None :
3922- torchrl_logger .warning (
3923- "use_buffer not specified and not yet inferred from data, assuming `True`."
3924- )
3920+ torchrl_logger .warning ("use_buffer not specified and not yet inferred from data, assuming `True`." )
39253921 elif not self ._use_buffers :
3926- raise RuntimeError (
3927- "Cannot concatenate results with use_buffers=False"
3928- )
3922+ raise RuntimeError ("Cannot concatenate results with use_buffers=False" )
39293923 try :
39303924 if same_device :
3931- self .out_buffer = torch .cat (list ( buffers . values ()) , cat_results )
3925+ self .out_buffer = torch .cat ([ item for item in buffers if item is not None ] , cat_results )
39323926 else :
3933- self .out_buffer = torch .cat (
3934- [item .cpu () for item in buffers .values ()], cat_results
3935- )
3927+ self .out_buffer = torch .cat ([item .cpu () for item in buffers if item is not None ], cat_results )
39363928 except RuntimeError as err :
3937- if (
3938- preempt
3939- and cat_results != - 1
3940- and "Sizes of tensors must match" in str (err )
3941- ):
3929+ if preempt and cat_results != - 1 and "Sizes of tensors must match" in str (err ):
39423930 raise RuntimeError (
39433931 "The value provided to cat_results isn't compatible with the collectors outputs. "
39443932 "Consider using `cat_results=-1`."
@@ -3956,11 +3944,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
39563944 self ._frames += n_collected
39573945
39583946 if self .postprocs :
3959- self .postprocs = (
3960- self .postprocs .to (out .device )
3961- if hasattr (self .postprocs , "to" )
3962- else self .postprocs
3963- )
3947+ self .postprocs = self .postprocs .to (out .device ) if hasattr (self .postprocs , "to" ) else self .postprocs
39643948 out = self .postprocs (out )
39653949 if self ._exclude_private_keys :
39663950 excluded_keys = [key for key in out .keys () if key .startswith ("_" )]
0 commit comments