Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 36 additions & 22 deletions serl_launcher/serl_launcher/data/memory_efficient_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,27 @@ def __init__(

observation_space = copy.deepcopy(observation_space)
self._num_stack = None
self._pixel_buffers = {}
self._pixel_buffer_size = capacity + 1
self._pixel_buffer_insert_index = 0
for pixel_key in self.pixel_keys:
pixel_obs_space = observation_space.spaces[pixel_key]
if self._num_stack is None:
self._num_stack = pixel_obs_space.shape[0]
else:
assert self._num_stack == pixel_obs_space.shape[0]
self._unstacked_dim_size = pixel_obs_space.shape[-1]
low = pixel_obs_space.low[0]
high = pixel_obs_space.high[0]
unstacked_pixel_obs_space = Box(
low=low, high=high, dtype=pixel_obs_space.dtype
self._pixel_buffer_size = capacity + self._num_stack
self._pixel_buffers[pixel_key] = np.empty(
(self._pixel_buffer_size, *pixel_obs_space.shape[1:]),
dtype=pixel_obs_space.dtype,
)
observation_space.spaces[pixel_key] = unstacked_pixel_obs_space
pointer_space = Box(
low=np.array(0, dtype=np.int32),
high=np.array(self._pixel_buffer_size - 1, dtype=np.int32),
shape=(),
dtype=np.int32,
)
observation_space.spaces[pixel_key] = pointer_space

next_observation_space_dict = copy.deepcopy(observation_space.spaces)
for pixel_key in self.pixel_keys:
Expand All @@ -50,14 +58,17 @@ def __init__(
next_observation_space=next_observation_space,
)

def insert(self, data_dict: DatasetDict):
if self._insert_index == 0 and self._capacity == len(self) and not self._first:
indxs = np.arange(len(self) - self._num_stack, len(self))
for indx in indxs:
element = super().sample(1, indx=indx)
self._is_correct_index[self._insert_index] = False
super().insert(element)
def _insert_pixel_frame(self, pixel_key: str, frame: np.ndarray) -> np.int32:
    """Write ``frame`` into the ``pixel_key`` buffer at the current insert slot.

    The insert index itself is left untouched by this method; callers move
    it forward separately via ``_advance_pixel_pointer``.

    Args:
        pixel_key: Key selecting which entry of ``self._pixel_buffers`` to
            write into.
        frame: The frame to store at the current slot.

    Returns:
        The slot that was written, as an ``np.int32``.
    """
    slot = np.int32(self._pixel_buffer_insert_index)
    target_buffer = self._pixel_buffers[pixel_key]
    target_buffer[slot] = frame
    return slot

def _advance_pixel_pointer(self):
    """Step the pixel-buffer insert index forward by one slot.

    The index wraps back to 0 once it reaches ``self._pixel_buffer_size``.
    """
    following_slot = self._pixel_buffer_insert_index + 1
    self._pixel_buffer_insert_index = following_slot % self._pixel_buffer_size

def insert(self, data_dict: DatasetDict):
data_dict = data_dict.copy()
data_dict["observations"] = data_dict["observations"].copy()
data_dict["next_observations"] = data_dict["next_observations"].copy()
Expand All @@ -71,18 +82,24 @@ def insert(self, data_dict: DatasetDict):
if self._first:
for i in range(self._num_stack):
for pixel_key in self.pixel_keys:
data_dict["observations"][pixel_key] = obs_pixels[pixel_key][i]
data_dict["observations"][pixel_key] = self._insert_pixel_frame(
pixel_key, obs_pixels[pixel_key][i]
)

self._is_correct_index[self._insert_index] = False
super().insert(data_dict)
self._advance_pixel_pointer()

for pixel_key in self.pixel_keys:
data_dict["observations"][pixel_key] = next_obs_pixels[pixel_key][-1]
data_dict["observations"][pixel_key] = self._insert_pixel_frame(
pixel_key, next_obs_pixels[pixel_key][-1]
)

self._first = data_dict["dones"]

self._is_correct_index[self._insert_index] = True
super().insert(data_dict)
self._advance_pixel_pointer()

for i in range(self._num_stack):
indx = (self._insert_index + i) % len(self)
Expand Down Expand Up @@ -146,13 +163,10 @@ def sample(
)

for pixel_key in self.pixel_keys:
obs_pixels = self.dataset_dict["observations"][pixel_key]
obs_pixels = np.lib.stride_tricks.sliding_window_view(
obs_pixels, self._num_stack + 1, axis=0
)
obs_pixels = obs_pixels[indx - self._num_stack]
# transpose from (B, H, W, C, T) to (B, T, H, W, C) to follow jaxrl_m convention
obs_pixels = obs_pixels.transpose((0, 4, 1, 2, 3))
history_offsets = np.arange(self._num_stack, -1, -1)
history_indx = (indx[:, None] - history_offsets[None, :]) % len(self)
pixel_pointers = self.dataset_dict["observations"][pixel_key][history_indx]
obs_pixels = self._pixel_buffers[pixel_key][pixel_pointers]

if pack_obs_and_next_obs:
batch["observations"][pixel_key] = obs_pixels
Expand Down