Skip to content

Commit 5f582de

Browse files
author
Luca Carminati
committed
Fix ordering of sampled data in MultiSyncDataCollector
1 parent 8570c25 commit 5f582de

File tree

1 file changed

+14
-30
lines changed

1 file changed

+14
-30
lines changed

torchrl/collectors/collectors.py

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3761,7 +3761,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
37613761
if cat_results is None:
37623762
cat_results = "stack"
37633763

3764-
self.buffers = {}
3764+
self.buffers = [None for _ in range(self.num_workers)]
37653765
dones = [False for _ in range(self.num_workers)]
37663766
workers_frames = [0 for _ in range(self.num_workers)]
37673767
same_device = None
@@ -3844,16 +3844,16 @@ def iterator(self) -> Iterator[TensorDictBase]:
38443844
if preempt:
38453845
# mask buffers if cat, and create a mask if stack
38463846
if cat_results != "stack":
3847-
buffers = {}
3848-
for worker_idx, buffer in self.buffers.items():
3847+
buffers = [None] * self.num_workers
3848+
for worker_idx, buffer in enumerate(filter(None.__ne__, self.buffers)):
38493849
valid = buffer.get(("collector", "traj_ids")) != -1
38503850
if valid.ndim > 2:
38513851
valid = valid.flatten(0, -2)
38523852
if valid.ndim == 2:
38533853
valid = valid.any(0)
38543854
buffers[worker_idx] = buffer[..., valid]
38553855
else:
3856-
for buffer in self.buffers.values():
3856+
for buffer in filter(None.__ne__, self.buffers):
38573857
with buffer.unlock_():
38583858
buffer.set(
38593859
("collector", "mask"),
@@ -3886,7 +3886,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
38863886
# we have to correct the traj_ids to make sure that they don't overlap
38873887
# We can count the number of frames collected for free in this loop
38883888
n_collected = 0
3889-
for idx in buffers.keys():
3889+
for idx,buffer in enumerate(filter(None.__ne__, buffers)):
38903890
buffer = buffers[idx]
38913891
traj_ids = buffer.get(("collector", "traj_ids"))
38923892
if preempt:
@@ -3901,7 +3901,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
39013901
if same_device is None:
39023902
prev_device = None
39033903
same_device = True
3904-
for item in self.buffers.values():
3904+
for item in filter(None.__ne__, self.buffers):
39053905
if prev_device is None:
39063906
prev_device = item.device
39073907
else:
@@ -3912,33 +3912,21 @@ def iterator(self) -> Iterator[TensorDictBase]:
39123912
torch.stack if self._use_buffers else TensorDict.maybe_dense_stack
39133913
)
39143914
if same_device:
3915-
self.out_buffer = stack(list(buffers.values()), 0)
3915+
self.out_buffer = stack([item for item in buffers if item is not None], 0)
39163916
else:
3917-
self.out_buffer = stack(
3918-
[item.cpu() for item in buffers.values()], 0
3919-
)
3917+
self.out_buffer = stack([item.cpu() for item in buffers if item is not None], 0)
39203918
else:
39213919
if self._use_buffers is None:
3922-
torchrl_logger.warning(
3923-
"use_buffer not specified and not yet inferred from data, assuming `True`."
3924-
)
3920+
torchrl_logger.warning("use_buffer not specified and not yet inferred from data, assuming `True`.")
39253921
elif not self._use_buffers:
3926-
raise RuntimeError(
3927-
"Cannot concatenate results with use_buffers=False"
3928-
)
3922+
raise RuntimeError("Cannot concatenate results with use_buffers=False")
39293923
try:
39303924
if same_device:
3931-
self.out_buffer = torch.cat(list(buffers.values()), cat_results)
3925+
self.out_buffer = torch.cat([item for item in buffers if item is not None], cat_results)
39323926
else:
3933-
self.out_buffer = torch.cat(
3934-
[item.cpu() for item in buffers.values()], cat_results
3935-
)
3927+
self.out_buffer = torch.cat([item.cpu() for item in buffers if item is not None], cat_results)
39363928
except RuntimeError as err:
3937-
if (
3938-
preempt
3939-
and cat_results != -1
3940-
and "Sizes of tensors must match" in str(err)
3941-
):
3929+
if preempt and cat_results != -1 and "Sizes of tensors must match" in str(err):
39423930
raise RuntimeError(
39433931
"The value provided to cat_results isn't compatible with the collectors outputs. "
39443932
"Consider using `cat_results=-1`."
@@ -3956,11 +3944,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
39563944
self._frames += n_collected
39573945

39583946
if self.postprocs:
3959-
self.postprocs = (
3960-
self.postprocs.to(out.device)
3961-
if hasattr(self.postprocs, "to")
3962-
else self.postprocs
3963-
)
3947+
self.postprocs = self.postprocs.to(out.device) if hasattr(self.postprocs, "to") else self.postprocs
39643948
out = self.postprocs(out)
39653949
if self._exclude_private_keys:
39663950
excluded_keys = [key for key in out.keys() if key.startswith("_")]

0 commit comments

Comments
 (0)