Skip to content

Commit 73eed6c

Browse files
author
Luca Carminati
committed
Fix ordering of sampled data in MultiSyncDataCollector
1 parent 5f582de commit 73eed6c

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

torchrl/collectors/collectors.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3845,7 +3845,9 @@ def iterator(self) -> Iterator[TensorDictBase]:
38453845
# mask buffers if cat, and create a mask if stack
38463846
if cat_results != "stack":
38473847
buffers = [None] * self.num_workers
3848-
for worker_idx, buffer in enumerate(filter(None.__ne__, self.buffers)):
3848+
for worker_idx, buffer in enumerate(
3849+
filter(None.__ne__, self.buffers)
3850+
):
38493851
valid = buffer.get(("collector", "traj_ids")) != -1
38503852
if valid.ndim > 2:
38513853
valid = valid.flatten(0, -2)
@@ -3886,7 +3888,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
38863888
# we have to correct the traj_ids to make sure that they don't overlap
38873889
# We can count the number of frames collected for free in this loop
38883890
n_collected = 0
3889-
for idx,buffer in enumerate(filter(None.__ne__, buffers)):
3891+
for idx, buffer in enumerate(filter(None.__ne__, buffers)):
38903892
buffer = buffers[idx]
38913893
traj_ids = buffer.get(("collector", "traj_ids"))
38923894
if preempt:
@@ -3912,19 +3914,28 @@ def iterator(self) -> Iterator[TensorDictBase]:
39123914
torch.stack if self._use_buffers else TensorDict.maybe_dense_stack
39133915
)
39143916
if same_device:
3915-
self.out_buffer = stack([item for item in buffers if item is not None], 0)
3917+
self.out_buffer = stack(
3918+
[item for item in buffers if item is not None], 0
3919+
)
39163920
else:
3917-
self.out_buffer = stack([item.cpu() for item in buffers if item is not None], 0)
3921+
self.out_buffer = stack(
3922+
[item.cpu() for item in buffers if item is not None], 0
3923+
)
39183924
else:
39193925
if self._use_buffers is None:
39203926
torchrl_logger.warning("use_buffer not specified and not yet inferred from data, assuming `True`.")
39213927
elif not self._use_buffers:
39223928
raise RuntimeError("Cannot concatenate results with use_buffers=False")
39233929
try:
39243930
if same_device:
3925-
self.out_buffer = torch.cat([item for item in buffers if item is not None], cat_results)
3931+
self.out_buffer = torch.cat(
3932+
[item for item in buffers if item is not None], cat_results
3933+
)
39263934
else:
3927-
self.out_buffer = torch.cat([item.cpu() for item in buffers if item is not None], cat_results)
3935+
self.out_buffer = torch.cat(
3936+
[item.cpu() for item in buffers if item is not None],
3937+
cat_results,
3938+
)
39283939
except RuntimeError as err:
39293940
if preempt and cat_results != -1 and "Sizes of tensors must match" in str(err):
39303941
raise RuntimeError(

0 commit comments

Comments
 (0)