Skip to content

Commit 599d93d

Browse files
author
Luca Carminati
committed
Revert None filtering logic and fix bugs
1 parent c428fd4 commit 599d93d

File tree

1 file changed

+27
-8
lines changed

1 file changed

+27
-8
lines changed

torchrl/collectors/collectors.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3843,15 +3843,17 @@ def iterator(self) -> Iterator[TensorDictBase]:
38433843
# mask buffers if cat, and create a mask if stack
38443844
if cat_results != "stack":
38453845
buffers = [None] * self.num_workers
3846-
for idx, buffer in enumerate(filter(None.__ne__, self.buffers)):
3846+
for worker_idx, buffer in enumerate(
3847+
filter(lambda x: x is not None, self.buffers)
3848+
):
38473849
valid = buffer.get(("collector", "traj_ids")) != -1
38483850
if valid.ndim > 2:
38493851
valid = valid.flatten(0, -2)
38503852
if valid.ndim == 2:
38513853
valid = valid.any(0)
3852-
buffers[idx] = buffer[..., valid]
3854+
buffers[worker_idx] = buffer[..., valid]
38533855
else:
3854-
for buffer in filter(None.__ne__, self.buffers):
3856+
for buffer in filter(lambda x: x is not None, self.buffers):
38553857
with buffer.unlock_():
38563858
buffer.set(
38573859
("collector", "mask"),
@@ -3861,6 +3863,11 @@ def iterator(self) -> Iterator[TensorDictBase]:
38613863
else:
38623864
buffers = self.buffers
38633865

3866+
# Skip frame counting if this worker didn't send data this iteration
3867+
# (happens when reusing buffers or on first iteration with some workers)
3868+
if self.buffers[idx] is None:
3869+
continue
3870+
38643871
workers_frames[idx] = workers_frames[idx] + buffers[idx].numel()
38653872

38663873
if workers_frames[idx] >= self.total_frames:
@@ -3892,7 +3899,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
38923899
if same_device is None:
38933900
prev_device = None
38943901
same_device = True
3895-
for item in filter(None.__ne__, self.buffers):
3902+
for item in filter(lambda x: x is not None, self.buffers):
38963903
if prev_device is None:
38973904
prev_device = item.device
38983905
else:
@@ -3912,9 +3919,13 @@ def iterator(self) -> Iterator[TensorDictBase]:
39123919
)
39133920
else:
39143921
if self._use_buffers is None:
3915-
torchrl_logger.warning("use_buffer not specified and not yet inferred from data, assuming `True`.")
3922+
torchrl_logger.warning(
3923+
"use_buffer not specified and not yet inferred from data, assuming `True`."
3924+
)
39163925
elif not self._use_buffers:
3917-
raise RuntimeError("Cannot concatenate results with use_buffers=False")
3926+
raise RuntimeError(
3927+
"Cannot concatenate results with use_buffers=False"
3928+
)
39183929
try:
39193930
if same_device:
39203931
self.out_buffer = torch.cat(
@@ -3926,7 +3937,11 @@ def iterator(self) -> Iterator[TensorDictBase]:
39263937
cat_results,
39273938
)
39283939
except RuntimeError as err:
3929-
if preempt and cat_results != -1 and "Sizes of tensors must match" in str(err):
3940+
if (
3941+
preempt
3942+
and cat_results != -1
3943+
and "Sizes of tensors must match" in str(err)
3944+
):
39303945
raise RuntimeError(
39313946
"The value provided to cat_results isn't compatible with the collectors outputs. "
39323947
"Consider using `cat_results=-1`."
@@ -3944,7 +3959,11 @@ def iterator(self) -> Iterator[TensorDictBase]:
39443959
self._frames += n_collected
39453960

39463961
if self.postprocs:
3947-
self.postprocs = self.postprocs.to(out.device) if hasattr(self.postprocs, "to") else self.postprocs
3962+
self.postprocs = (
3963+
self.postprocs.to(out.device)
3964+
if hasattr(self.postprocs, "to")
3965+
else self.postprocs
3966+
)
39483967
out = self.postprocs(out)
39493968
if self._exclude_private_keys:
39503969
excluded_keys = [key for key in out.keys() if key.startswith("_")]

0 commit comments

Comments
 (0)