-
Notifications
You must be signed in to change notification settings - Fork 4
/
cartpole_dqn.py
205 lines (163 loc) · 6.76 KB
/
cartpole_dqn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import sys
import gym
import torch
import pylab
import random
import numpy as np
from collections import deque
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms
# Number of training episodes the main loop runs (it may stop earlier once solved).
EPISODES: int = 500
# Q-network: a small fully connected net that maps a state vector to one
# Q-value per action.
class DQN(nn.Module):
    """Approximate the Q function with an MLP.

    Architecture: ``state_size -> 24 -> ReLU -> 24 -> ReLU -> action_size``.
    The output layer is linear, so Q-values are unbounded. All layers live in
    ``self.fc`` (an ``nn.Sequential``), so the parameter names are ``fc.0.*``,
    ``fc.2.*`` and ``fc.4.*``.
    """

    def __init__(self, state_size, action_size):
        super().__init__()
        hidden = 24
        layers = [
            nn.Linear(state_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_size),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return one Q-value per action for ``x`` (last dim ``state_size``)."""
        q_values = self.fc(x)
        return q_values
# DQN Agent for the Cartpole
# it uses Neural Network to approximate q function
# and replay memory & target q network
class DQNAgent:
    """Epsilon-greedy DQN agent with experience replay and a target network.

    Args:
        state_size: length of one observation vector (network input size).
        action_size: number of discrete actions (network output size).
    """

    def __init__(self, state_size, action_size):
        # if you want to see Cartpole learning, then change to True
        self.render = False
        # if True, replace the freshly built network with a whole pickled
        # model loaded from save_model/cartpole_dqn
        self.load_model = False

        # get size of state and action
        self.state_size = state_size
        self.action_size = action_size

        # These are hyper parameters for the DQN
        self.discount_factor = 0.99
        self.learning_rate = 0.001
        self.memory_size = 20000
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        # epsilon falls linearly from 1.0 to epsilon_min over this many
        # train_model() calls
        self.explore_step = 5000
        self.epsilon_decay = (self.epsilon - self.epsilon_min) / self.explore_step
        self.batch_size = 64
        # replay size the caller waits for before calling train_model();
        # not read inside this class
        self.train_start = 1000

        # create replay memory using deque (oldest samples drop out when full)
        self.memory = deque(maxlen=self.memory_size)

        # create main model and target model
        self.model = DQN(state_size, action_size)
        self.model.apply(self.weights_init)
        self.target_model = DQN(state_size, action_size)

        # Load BEFORE building the optimizer and syncing the target network.
        # Previously the model was swapped afterwards, so Adam kept updating
        # the discarded network's parameters and the target net kept the
        # random initial weights.
        if self.load_model:
            # NOTE(review): this unpickles a whole nn.Module -- only load
            # trusted files. torch >= 2.6 defaults to weights_only=True and
            # rejects such files; pass weights_only=False there.
            self.model = torch.load('save_model/cartpole_dqn')

        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.learning_rate)
        # initialize target model
        self.update_target_model()

    def weights_init(self, m):
        """Xavier-uniform initialise the weight of every ``nn.Linear`` module.

        Meant to be passed to ``nn.Module.apply``; other module types and all
        biases are left unchanged.
        """
        if isinstance(m, nn.Linear):
            # in-place variant; the non-underscore name is deprecated
            nn.init.xavier_uniform_(m.weight)

    def update_target_model(self):
        """Copy the online model's weights into the target model."""
        self.target_model.load_state_dict(self.model.state_dict())

    def get_action(self, state):
        """Pick an action for ``state`` with an epsilon-greedy policy.

        Args:
            state: NumPy array holding a single state row, shape
                (1, state_size); the argmax over dim 1 and the ``int()``
                conversion below require exactly one row.

        Returns:
            int action index in ``[0, action_size)``.
        """
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        # greedy action: no autograd graph is needed for inference
        with torch.no_grad():
            q_value = self.model(torch.from_numpy(state).float())
        _, action = torch.max(q_value, 1)
        return int(action)

    def append_sample(self, state, action, reward, next_state, done):
        """Store one transition <s, a, r, s', done> in the replay memory."""
        self.memory.append((state, action, reward, next_state, done))

    def train_model(self):
        """Anneal epsilon, then do one Q-learning step on a random minibatch.

        Raises:
            ValueError: from ``random.sample`` if the replay memory holds
                fewer than ``batch_size`` transitions.
        """
        # linear epsilon annealing; clamp so float error cannot undershoot
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min,
                               self.epsilon - self.epsilon_decay)

        mini_batch = random.sample(self.memory, self.batch_size)
        # Unzip field-wise. np.array() over these heterogeneous tuples
        # (arrays, ints, floats, bools) raises ValueError on NumPy >= 1.24.
        states, actions, rewards, next_states, dones = zip(*mini_batch)

        states = torch.from_numpy(np.vstack(states)).float()
        next_states = torch.from_numpy(np.vstack(next_states)).float()
        actions = torch.from_numpy(np.asarray(actions, dtype=np.int64)).view(-1, 1)
        rewards = torch.from_numpy(np.asarray(rewards, dtype=np.float32))
        # bool to binary: 1.0 for terminal transitions
        dones = torch.from_numpy(np.asarray(dones, dtype=np.float32))

        # Q(s, a) of the actions actually taken
        pred = self.model(states).gather(1, actions).squeeze(1)

        # Q Learning target: r + gamma * max_a' Q_target(s', a'), with no
        # bootstrap on terminal transitions; no gradient flows into it
        with torch.no_grad():
            next_q = self.target_model(next_states).max(1)[0]
            target = rewards + (1 - dones) * self.discount_factor * next_q

        self.optimizer.zero_grad()
        # MSE Loss function
        loss = F.mse_loss(pred, target)
        loss.backward()
        # and train
        self.optimizer.step()
if __name__ == "__main__":
    import os  # local: only the training script needs filesystem setup

    # In case of CartPole-v1, maximum length of episode is 500
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size)

    # savefig / torch.save below fail if the target directories are missing
    os.makedirs("./save_graph", exist_ok=True)
    os.makedirs("./save_model", exist_ok=True)

    scores, episodes = [], []
    for e in range(EPISODES):
        done = False
        score = 0
        state = env.reset()
        # gym >= 0.26 returns (observation, info); older gym returns only
        # the observation
        if isinstance(state, tuple):
            state = state[0]
        state = np.reshape(state, [1, state_size])

        while not done:
            if agent.render:
                env.render()

            # get action for the current state and go one step in environment
            action = agent.get_action(state)
            step_result = env.step(action)
            if len(step_result) == 5:
                # gym >= 0.26: (obs, reward, terminated, truncated, info)
                next_state, reward, terminated, truncated, info = step_result
                done = terminated or truncated
            else:
                next_state, reward, done, info = step_result
            next_state = np.reshape(next_state, [1, state_size])

            # reward shaping: an action that ends the episode before the
            # 500-step limit is stored with a penalty of -10
            shaped_reward = reward if not done or score == 499 else -10

            # save the sample <s, a, r, s'> to the replay memory
            agent.append_sample(state, action, shaped_reward, next_state, done)
            # every time step do the training once the buffer is warm
            if len(agent.memory) >= agent.train_start:
                agent.train_model()

            # score sums the unshaped environment reward, i.e. the episode
            # length for CartPole's +1-per-step reward. (Adding the shaped
            # reward and then "+10" left failed episodes one step short,
            # because the penalty replaced a +1 reward.)
            score += reward
            state = next_state

        # every episode update the target model to be same with model
        agent.update_target_model()

        # every episode, plot the play time
        scores.append(score)
        episodes.append(e)
        pylab.clf()  # redraw rather than stacking one more line per episode
        pylab.plot(episodes, scores, 'b')
        pylab.savefig("./save_graph/cartpole_dqn.png")
        print("episode:", e, " score:", score, " memory length:",
              len(agent.memory), " epsilon:", agent.epsilon)

        # if the mean of scores of last 10 episodes is bigger than 490,
        # save the model and stop training
        if np.mean(scores[-10:]) > 490:
            torch.save(agent.model, "./save_model/cartpole_dqn")
            break

    env.close()