forked from riiswa/kanrl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
buffer.py
37 lines (30 loc) · 1.21 KB
/
buffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import numpy as np
class ReplayBuffer:
    """Fixed-capacity circular buffer of transitions backed by pre-allocated tensors.

    Transitions are written in insertion order; once ``capacity`` transitions
    have been stored, each further ``add`` overwrites the oldest slot.

    Attributes:
        capacity: Maximum number of transitions retained.
        observations: ``(capacity, *obs_shape)`` tensor of the default float dtype.
        actions: ``(capacity, 1)`` int64 tensor, one integer action per transition.
        next_observations: Same shape/dtype as ``observations``.
        rewards: ``(capacity, 1)`` tensor of the default float dtype.
        terminations: ``(capacity, 1)`` ``torch.int`` tensor.
        cursor: Total number of ``add`` calls so far (never wrapped).
    """

    def __init__(self, capacity, observation_dim):
        """Allocate zero-filled storage for ``capacity`` transitions.

        Args:
            capacity: Maximum number of transitions kept; must be positive.
            observation_dim: Length of a flat observation vector (an int), or a
                shape sequence such as ``(C, H, W)`` for multi-dimensional
                observations.

        Raises:
            ValueError: If ``capacity`` is not positive. A zero capacity would
                otherwise only fail later, with ``ZeroDivisionError`` in ``add``.
        """
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        # A scalar keeps the flat (capacity, observation_dim) layout; a
        # sequence is taken as the full per-observation shape.
        if np.ndim(observation_dim) == 0:
            obs_shape = (observation_dim,)
        else:
            obs_shape = tuple(observation_dim)
        self.capacity = capacity
        self.observations = torch.zeros((capacity, *obs_shape))
        self.actions = torch.zeros(capacity, 1, dtype=torch.int64)
        self.next_observations = torch.zeros((capacity, *obs_shape))
        self.rewards = torch.zeros(capacity, 1)
        self.terminations = torch.zeros(capacity, 1, dtype=torch.int)
        # Monotonic write counter; the physical slot is cursor % capacity.
        self.cursor = 0

    def add(self, observation, action, next_observation, reward, termination):
        """Store one transition, overwriting the oldest slot once the buffer is full.

        Each value is copied into one row of the matching pre-allocated tensor,
        so it must be assignable to that row (presumably a tensor of the
        observation shape, or a Python scalar for action/reward/termination).
        """
        index = self.cursor % self.capacity
        self.observations[index] = observation
        self.actions[index] = action
        self.next_observations[index] = next_observation
        self.rewards[index] = reward
        self.terminations[index] = termination
        self.cursor += 1

    def sample(self, batch_size):
        """Draw up to ``batch_size`` distinct stored transitions uniformly at random.

        Indices are drawn without replacement from the stored transitions, so if
        fewer than ``batch_size`` are stored the returned batch is smaller (and
        has zero rows when the buffer is empty). Randomness comes from NumPy's
        global RNG, so ``np.random.seed`` makes sampling reproducible.

        Returns:
            Tuple ``(observations, actions, next_observations, rewards,
            terminations)`` of tensors sharing the same leading batch dimension.
        """
        # permutation(n) shuffles arange(n) internally; it yields the same
        # indices as permutation(np.arange(n)) for the same RNG state.
        idx = np.random.permutation(len(self))[:batch_size]
        return (
            self.observations[idx],
            self.actions[idx],
            self.next_observations[idx],
            self.rewards[idx],
            self.terminations[idx],
        )

    def __len__(self):
        """Return the number of transitions currently stored (at most ``capacity``)."""
        return min(self.cursor, self.capacity)