Source code for simple_replay_buffer

from typing import Optional
import numpy as np
import random
from collections import deque


[docs]class SimpleReplayBuffer: def __init__(self, capacity: Optional[int] = None):
[docs] self.capacity = capacity
[docs] self.buffer = deque(maxlen=capacity)
[docs] def store(self, state, action, reward, next_state, terminated, truncated): """Store a transition in the buffer.""" self.buffer.append((state, action, reward, next_state, terminated, truncated))
[docs] def extend(self, replay_buffer: "SimpleReplayBuffer"): """Extend the buffer with another buffer.""" self.buffer.extend(replay_buffer.buffer)
[docs] def sample( self, batch_size: int = 1, sequential: bool = False, ): """Sample a batch of experiences.""" if batch_size > self.size(): raise ValueError("Batch size is greater than buffer size.") if sequential: start_idx = random.randint(0, self.size() - batch_size) batch = [self.buffer[i] for i in range(start_idx, start_idx + batch_size)] else: batch = random.sample(self.buffer, batch_size) states, actions, rewards, next_states, terminated, truncated = zip(*batch) states_arr = np.array(states) actions_arr = np.array(actions) rewards_arr = np.array(rewards) next_states_arr = np.array(next_states) terminated_arr = np.array(terminated) truncated_arr = np.array(truncated) if states_arr.ndim == 1: states_arr = states_arr.reshape(-1, 1) if next_states_arr.ndim == 1: next_states_arr = next_states_arr.reshape(-1, 1) return ( states_arr, actions_arr, rewards_arr, next_states_arr, terminated_arr, truncated_arr, )
[docs] def size(self): """Return current buffer size.""" return len(self.buffer)
[docs] def clear(self): """Clear the buffer.""" self.buffer.clear()
[docs] def __len__(self): """Return the current buffer size.""" return self.size()
[docs] def pop(self, num_elements: int = 1): """Pop elements from the buffer.""" if num_elements > self.size(): raise ValueError("Number of elements to pop is greater than buffer size.") return [self.buffer.popleft() for _ in range(num_elements)]
if __name__ == "__main__":
[docs] buffer = SimpleReplayBuffer(None)
buffer.store(np.array([0, 0, 0]), 0, 1, np.array([1, 1, 1]), False, False) buffer.store(np.array([2, 2, 2]), 1, 0.5, np.array([3, 3, 3]), True, False) batch = buffer.sample(2) print(batch) print(buffer.size()) batch_2 = buffer.sample(2, sequential=True) print(batch_2) buffer.clear() print(buffer.size())