@@ -3843,15 +3843,17 @@ def iterator(self) -> Iterator[TensorDictBase]:
38433843 # mask buffers if cat, and create a mask if stack
38443844 if cat_results != "stack" :
38453845 buffers = [None ] * self .num_workers
3846- for idx , buffer in enumerate (filter (None .__ne__ , self .buffers )):
3846+ for worker_idx , buffer in enumerate (
3847+ filter (lambda x : x is not None , self .buffers )
3848+ ):
38473849 valid = buffer .get (("collector" , "traj_ids" )) != - 1
38483850 if valid .ndim > 2 :
38493851 valid = valid .flatten (0 , - 2 )
38503852 if valid .ndim == 2 :
38513853 valid = valid .any (0 )
3852- buffers [idx ] = buffer [..., valid ]
3854+ buffers [worker_idx ] = buffer [..., valid ]
38533855 else :
3854- for buffer in filter (None . __ne__ , self .buffers ):
3856+ for buffer in filter (lambda x : x is not None , self .buffers ):
38553857 with buffer .unlock_ ():
38563858 buffer .set (
38573859 ("collector" , "mask" ),
@@ -3861,6 +3863,11 @@ def iterator(self) -> Iterator[TensorDictBase]:
38613863 else :
38623864 buffers = self .buffers
38633865
3866+ # Skip frame counting if this worker didn't send data this iteration
3867+ # (happens when reusing buffers or on first iteration with some workers)
3868+ if self .buffers [idx ] is None :
3869+ continue
3870+
38643871 workers_frames [idx ] = workers_frames [idx ] + buffers [idx ].numel ()
38653872
38663873 if workers_frames [idx ] >= self .total_frames :
@@ -3892,7 +3899,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
38923899 if same_device is None :
38933900 prev_device = None
38943901 same_device = True
3895- for item in filter (None . __ne__ , self .buffers ):
3902+ for item in filter (lambda x : x is not None , self .buffers ):
38963903 if prev_device is None :
38973904 prev_device = item .device
38983905 else :
@@ -3912,9 +3919,13 @@ def iterator(self) -> Iterator[TensorDictBase]:
39123919 )
39133920 else :
39143921 if self ._use_buffers is None :
3915- torchrl_logger .warning ("use_buffer not specified and not yet inferred from data, assuming `True`." )
3922+ torchrl_logger .warning (
3923+ "use_buffer not specified and not yet inferred from data, assuming `True`."
3924+ )
39163925 elif not self ._use_buffers :
3917- raise RuntimeError ("Cannot concatenate results with use_buffers=False" )
3926+ raise RuntimeError (
3927+ "Cannot concatenate results with use_buffers=False"
3928+ )
39183929 try :
39193930 if same_device :
39203931 self .out_buffer = torch .cat (
@@ -3926,7 +3937,11 @@ def iterator(self) -> Iterator[TensorDictBase]:
39263937 cat_results ,
39273938 )
39283939 except RuntimeError as err :
3929- if preempt and cat_results != - 1 and "Sizes of tensors must match" in str (err ):
3940+ if (
3941+ preempt
3942+ and cat_results != - 1
3943+ and "Sizes of tensors must match" in str (err )
3944+ ):
39303945 raise RuntimeError (
39313946 "The value provided to cat_results isn't compatible with the collectors outputs. "
39323947 "Consider using `cat_results=-1`."
@@ -3944,7 +3959,11 @@ def iterator(self) -> Iterator[TensorDictBase]:
39443959 self ._frames += n_collected
39453960
39463961 if self .postprocs :
3947- self .postprocs = self .postprocs .to (out .device ) if hasattr (self .postprocs , "to" ) else self .postprocs
3962+ self .postprocs = (
3963+ self .postprocs .to (out .device )
3964+ if hasattr (self .postprocs , "to" )
3965+ else self .postprocs
3966+ )
39483967 out = self .postprocs (out )
39493968 if self ._exclude_private_keys :
39503969 excluded_keys = [key for key in out .keys () if key .startswith ("_" )]
0 commit comments