From 4f76891894dd4d19334ae6f75372791940b4b4d9 Mon Sep 17 00:00:00 2001 From: "Dr. Juan Rojas" Date: Tue, 10 Mar 2026 21:03:28 -0500 Subject: [PATCH] Use pixel frame pointers in memory efficient replay buffer --- .../data/memory_efficient_replay_buffer.py | 58 ++++++++++++------- 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/serl_launcher/serl_launcher/data/memory_efficient_replay_buffer.py b/serl_launcher/serl_launcher/data/memory_efficient_replay_buffer.py index d94f1143..079cfe66 100644 --- a/serl_launcher/serl_launcher/data/memory_efficient_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/memory_efficient_replay_buffer.py @@ -21,19 +21,27 @@ def __init__( observation_space = copy.deepcopy(observation_space) self._num_stack = None + self._pixel_buffers = {} + self._pixel_buffer_size = capacity + 1 + self._pixel_buffer_insert_index = 0 for pixel_key in self.pixel_keys: pixel_obs_space = observation_space.spaces[pixel_key] if self._num_stack is None: self._num_stack = pixel_obs_space.shape[0] else: assert self._num_stack == pixel_obs_space.shape[0] - self._unstacked_dim_size = pixel_obs_space.shape[-1] - low = pixel_obs_space.low[0] - high = pixel_obs_space.high[0] - unstacked_pixel_obs_space = Box( - low=low, high=high, dtype=pixel_obs_space.dtype + self._pixel_buffer_size = capacity + self._num_stack + self._pixel_buffers[pixel_key] = np.empty( + (self._pixel_buffer_size, *pixel_obs_space.shape[1:]), + dtype=pixel_obs_space.dtype, ) - observation_space.spaces[pixel_key] = unstacked_pixel_obs_space + pointer_space = Box( + low=np.array(0, dtype=np.int32), + high=np.array(self._pixel_buffer_size - 1, dtype=np.int32), + shape=(), + dtype=np.int32, + ) + observation_space.spaces[pixel_key] = pointer_space next_observation_space_dict = copy.deepcopy(observation_space.spaces) for pixel_key in self.pixel_keys: @@ -50,14 +58,17 @@ def __init__( next_observation_space=next_observation_space, ) - def insert(self, data_dict: DatasetDict): - if self._insert_index == 0 and self._capacity == len(self) and not self._first: - indxs = np.arange(len(self) - self._num_stack, len(self)) - for indx in indxs: - element = super().sample(1, indx=indx) - self._is_correct_index[self._insert_index] = False - super().insert(element) + def _insert_pixel_frame(self, pixel_key: str, frame: np.ndarray) -> np.int32: + pointer = np.int32(self._pixel_buffer_insert_index) + self._pixel_buffers[pixel_key][pointer] = frame + return pointer + + def _advance_pixel_pointer(self): + self._pixel_buffer_insert_index = ( + self._pixel_buffer_insert_index + 1 + ) % self._pixel_buffer_size + def insert(self, data_dict: DatasetDict): data_dict = data_dict.copy() data_dict["observations"] = data_dict["observations"].copy() data_dict["next_observations"] = data_dict["next_observations"].copy() @@ -71,18 +82,24 @@ def insert(self, data_dict: DatasetDict): if self._first: for i in range(self._num_stack): for pixel_key in self.pixel_keys: - data_dict["observations"][pixel_key] = obs_pixels[pixel_key][i] + data_dict["observations"][pixel_key] = self._insert_pixel_frame( + pixel_key, obs_pixels[pixel_key][i] + ) self._is_correct_index[self._insert_index] = False super().insert(data_dict) + self._advance_pixel_pointer() for pixel_key in self.pixel_keys: - data_dict["observations"][pixel_key] = next_obs_pixels[pixel_key][-1] + data_dict["observations"][pixel_key] = self._insert_pixel_frame( + pixel_key, next_obs_pixels[pixel_key][-1] + ) self._first = data_dict["dones"] self._is_correct_index[self._insert_index] = True super().insert(data_dict) + self._advance_pixel_pointer() for i in range(self._num_stack): indx = (self._insert_index + i) % len(self) @@ -146,13 +163,10 @@ def sample( ) for pixel_key in self.pixel_keys: - obs_pixels = self.dataset_dict["observations"][pixel_key] - obs_pixels = np.lib.stride_tricks.sliding_window_view( - obs_pixels, self._num_stack + 1, axis=0 - ) - obs_pixels = obs_pixels[indx - self._num_stack] - # transpose from (B, H, W, C, T) to (B, T, H, W, C) to follow jaxrl_m convention - obs_pixels = obs_pixels.transpose((0, 4, 1, 2, 3)) + history_offsets = np.arange(self._num_stack, -1, -1) + history_indx = (indx[:, None] - history_offsets[None, :]) % len(self) + pixel_pointers = self.dataset_dict["observations"][pixel_key][history_indx] + obs_pixels = self._pixel_buffers[pixel_key][pixel_pointers] if pack_obs_and_next_obs: batch["observations"][pixel_key] = obs_pixels