[Quality] RB constructors cleanup (pytorch#945)
vmoens authored Mar 2, 2023
1 parent 732e3a2 commit 501b3af
Showing 3 changed files with 68 additions and 15 deletions.
63 changes: 62 additions & 1 deletion test/test_rb.py
@@ -75,7 +75,7 @@
 @pytest.mark.parametrize("writer", [writers.RoundRobinWriter])
 @pytest.mark.parametrize("storage", [ListStorage, LazyTensorStorage, LazyMemmapStorage])
 @pytest.mark.parametrize("size", [3, 5, 100])
-class TestPrototypeBuffers:
+class TestComposableBuffers:
     def _get_rb(self, rb_type, size, sampler, writer, storage):

         if storage is not None:
@@ -884,6 +884,67 @@ def test_samplerwithoutrep(size, samples, drop_last):
     assert not visited


+class TestStateDict:
+    @pytest.mark.parametrize("storage_in", ["tensor", "memmap"])
+    @pytest.mark.parametrize("storage_out", ["tensor", "memmap"])
+    @pytest.mark.parametrize("init_out", [True, False])
+    def test_load_state_dict(self, storage_in, storage_out, init_out):
+        buffer_size = 100
+        if storage_in == "memmap":
+            storage_in = LazyMemmapStorage(
+                buffer_size,
+                device="cpu",
+            )
+        elif storage_in == "tensor":
+            storage_in = LazyTensorStorage(
+                buffer_size,
+                device="cpu",
+            )
+        if storage_out == "memmap":
+            storage_out = LazyMemmapStorage(
+                buffer_size,
+                device="cpu",
+            )
+        elif storage_out == "tensor":
+            storage_out = LazyTensorStorage(
+                buffer_size,
+                device="cpu",
+            )
+
+        replay_buffer = TensorDictReplayBuffer(
+            pin_memory=False,
+            prefetch=3,
+            storage=storage_in,
+        )
+        # fill replay buffer with random data
+        transition = TensorDict(
+            {
+                "observation": torch.ones(1, 4),
+                "action": torch.ones(1, 2),
+                "reward": torch.ones(1, 1),
+                "dones": torch.ones(1, 1),
+                "next": {"observation": torch.ones(1, 4)},
+            },
+            batch_size=1,
+        )
+        for _ in range(3):
+            replay_buffer.extend(transition)
+
+        state_dict = replay_buffer.state_dict()
+
+        new_replay_buffer = TensorDictReplayBuffer(
+            pin_memory=False,
+            prefetch=3,
+            storage=storage_out,
+        )
+        if init_out:
+            new_replay_buffer.extend(transition)
+
+        new_replay_buffer.load_state_dict(state_dict)
+        s = new_replay_buffer.sample(3)
+        assert (s.exclude("index") == 1).all()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
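
The new TestStateDict class pins down the behavior this commit enables end to end: a buffer's state_dict() can be restored into a buffer built on a different storage backend. A condensed sketch of that round-trip outside the pytest harness (hedged: it assumes the public torchrl imports below and, like the test, that sample() adds an "index" key):

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, LazyTensorStorage, TensorDictReplayBuffer

# Fill a tensor-backed buffer, snapshot it, then restore the snapshot
# into a memmap-backed buffer.
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(100))
rb.extend(TensorDict({"observation": torch.ones(1, 4)}, batch_size=1))
state = rb.state_dict()

rb2 = TensorDictReplayBuffer(storage=LazyMemmapStorage(100))
rb2.load_state_dict(state)
assert (rb2.sample(1).exclude("index") == 1).all()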
4 changes: 2 additions & 2 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -341,8 +341,8 @@ class TensorDictReplayBuffer(ReplayBuffer):
         within TensorDicts added to this ReplayBuffer.
     """

-    def __init__(self, priority_key: str = "td_error", **kw) -> None:
-        super().__init__(**kw)
+    def __init__(self, *args, priority_key: str = "td_error", **kw) -> None:
+        super().__init__(*args, **kw)
         self.priority_key = priority_key

     def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]:
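
The one-line constructor change is the cleanup the commit title refers to: TensorDictReplayBuffer previously accepted keyword arguments only, so any positional argument raised a TypeError before it could reach ReplayBuffer. With *args forwarded, the subclass mirrors whatever signature its parent exposes while keeping priority_key for itself. A minimal sketch of the resulting contract (keyword style shown, since the parent's positional signature depends on the torchrl version):

from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer

# priority_key is consumed by TensorDictReplayBuffer itself; every other
# argument, positional or keyword, is handed to ReplayBuffer untouched.
rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(100),
    prefetch=3,
    priority_key="td_error",
)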
16 changes: 4 additions & 12 deletions torchrl/data/replay_buffers/storages.py
@@ -12,7 +12,7 @@
 import torch
 from tensordict.memmap import MemmapTensor
 from tensordict.prototype import is_tensorclass
-from tensordict.tensordict import TensorDict, TensorDictBase
+from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase

 from torchrl._utils import _CKPT_BACKEND
 from torchrl.data.replay_buffers.utils import INT_CLASSES
@@ -210,11 +210,7 @@ def load_state_dict(self, state_dict):
             if isinstance(self._storage, TensorDictBase):
                 self._storage.load_state_dict(_storage)
             elif self._storage is None:
-                batch_size = _storage.pop("__batch_size")
-                device = _storage.pop("__device")
-                self._storage = TensorDict(
-                    _storage, batch_size=batch_size, device=device
-                )
+                self._storage = TensorDict({}, []).load_state_dict(_storage)
             else:
                 raise RuntimeError(
                     f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}"
@@ -333,15 +329,11 @@ def load_state_dict(self, state_dict):
                     f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}"
                 )
         elif isinstance(_storage, (dict, OrderedDict)):
-            if isinstance(self._storage, TensorDictBase):
+            if is_tensor_collection(self._storage):
                 self._storage.load_state_dict(_storage)
                 self._storage.memmap_()
             elif self._storage is None:
-                batch_size = _storage.pop("__batch_size")
-                device = _storage.pop("__device")
-                self._storage = TensorDict(
-                    _storage, batch_size=batch_size, device=device
-                )
+                self._storage = TensorDict({}, []).load_state_dict(_storage)
                 self._storage.memmap_()
             else:
                 raise RuntimeError(
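
Both storage hunks replace the hand-rolled metadata handling, where "__batch_size" and "__device" were popped and a TensorDict was rebuilt around the remainder, with a single call: tensordict's load_state_dict repopulates an empty TensorDict and returns it, which is why the result can be assigned directly. A small sketch of that idiom, assuming TensorDict.state_dict() emits the metadata entries the removed code was popping:

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3], device="cpu")
state = td.state_dict()  # carries "__batch_size" and "__device" entries

# load_state_dict rebuilds both contents and metadata and returns self,
# so the restored tensordict can be assigned in one expression.
restored = TensorDict({}, []).load_state_dict(state)
assert restored.batch_size == td.batch_size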
