@@ -3845,7 +3845,9 @@ def iterator(self) -> Iterator[TensorDictBase]:
38453845 # mask buffers if cat, and create a mask if stack
38463846 if cat_results != "stack" :
38473847 buffers = [None ] * self .num_workers
3848- for worker_idx , buffer in enumerate (filter (None .__ne__ , self .buffers )):
3848+ for worker_idx , buffer in enumerate (
3849+ filter (None .__ne__ , self .buffers )
3850+ ):
38493851 valid = buffer .get (("collector" , "traj_ids" )) != - 1
38503852 if valid .ndim > 2 :
38513853 valid = valid .flatten (0 , - 2 )
@@ -3886,7 +3888,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
38863888 # we have to correct the traj_ids to make sure that they don't overlap
38873889 # We can count the number of frames collected for free in this loop
38883890 n_collected = 0
3889- for idx ,buffer in enumerate (filter (None .__ne__ , buffers )):
3891+ for idx , buffer in enumerate (filter (None .__ne__ , buffers )):
38903892 buffer = buffers [idx ]
38913893 traj_ids = buffer .get (("collector" , "traj_ids" ))
38923894 if preempt :
@@ -3912,19 +3914,28 @@ def iterator(self) -> Iterator[TensorDictBase]:
39123914 torch .stack if self ._use_buffers else TensorDict .maybe_dense_stack
39133915 )
39143916 if same_device :
3915- self .out_buffer = stack ([item for item in buffers if item is not None ], 0 )
3917+ self .out_buffer = stack (
3918+ [item for item in buffers if item is not None ], 0
3919+ )
39163920 else :
3917- self .out_buffer = stack ([item .cpu () for item in buffers if item is not None ], 0 )
3921+ self .out_buffer = stack (
3922+ [item .cpu () for item in buffers if item is not None ], 0
3923+ )
39183924 else :
39193925 if self ._use_buffers is None :
39203926 torchrl_logger .warning ("use_buffer not specified and not yet inferred from data, assuming `True`." )
39213927 elif not self ._use_buffers :
39223928 raise RuntimeError ("Cannot concatenate results with use_buffers=False" )
39233929 try :
39243930 if same_device :
3925- self .out_buffer = torch .cat ([item for item in buffers if item is not None ], cat_results )
3931+ self .out_buffer = torch .cat (
3932+ [item for item in buffers if item is not None ], cat_results
3933+ )
39263934 else :
3927- self .out_buffer = torch .cat ([item .cpu () for item in buffers if item is not None ], cat_results )
3935+ self .out_buffer = torch .cat (
3936+ [item .cpu () for item in buffers if item is not None ],
3937+ cat_results ,
3938+ )
39283939 except RuntimeError as err :
39293940 if preempt and cat_results != - 1 and "Sizes of tensors must match" in str (err ):
39303941 raise RuntimeError (
0 commit comments